Skip to main content

mz_interchange/avro/
decode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15    AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16    GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::ConfluentAvroResolver;
30
/// Manages decoding of Avro-encoded bytes.
#[derive(Debug)]
pub struct Decoder {
    /// Resolves the schema for each incoming record. Its `resolve` call hands
    /// back the payload bytes to decode, the resolved schema, and (when known)
    /// the schema id used in error messages.
    csr_avro: ConfluentAvroResolver,
    /// Name of this decoder, included in trace logging of decoded rows.
    debug_name: String,
    /// Scratch buffer lent to `AvroFlatDecoder` for reading variable-length
    /// values; kept on the struct so its allocation is reused across calls.
    buf1: Vec<u8>,
    /// Reusable row that each decode packs into; a successful decode returns
    /// a clone of it.
    row_buf: Row,
}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A failed decode must not leave state behind in the decoder that
    /// corrupts the next, valid, decode.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, &[], None, "Test".to_string(), false).unwrap();

        // A lone zero byte encodes `f1` but leaves nothing for `f2`, so it is
        // not a valid record for this schema.
        let mut truncated: &[u8] = &[0];
        let first_attempt = decoder.decode(&mut truncated).await.unwrap();
        assert_err!(first_attempt);

        // Two zero bytes encode `f1 = 0` and `f2 = 0`; this must decode
        // cleanly despite the failure above.
        let mut well_formed: &[u8] = &[0, 0];
        let decoded = decoder.decode(&mut well_formed).await.unwrap().unwrap();
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded, expected);
    }
}
67
68impl Decoder {
69    /// Creates a new `Decoder`
70    ///
71    /// The provided schema is called the "reader schema", which is the schema
72    /// that we are expecting to use to decode records. The records may indicate
73    /// that they are encoded with a different schema; as long as those.
74    ///
75    /// The `reader_reference_schemas` parameter provides schemas for types that
76    /// are referenced by the reader schema but defined in separate schemas.
77    /// These should be provided in dependency order (dependencies first).
78    pub fn new(
79        reader_schema: &str,
80        reader_reference_schemas: &[String],
81        ccsr_client: Option<mz_ccsr::Client>,
82        debug_name: String,
83        confluent_wire_format: bool,
84    ) -> anyhow::Result<Decoder> {
85        let csr_avro = ConfluentAvroResolver::new(
86            reader_schema,
87            reader_reference_schemas,
88            ccsr_client,
89            confluent_wire_format,
90        )?;
91
92        Ok(Decoder {
93            csr_avro,
94            debug_name,
95            buf1: vec![],
96            row_buf: Row::default(),
97        })
98    }
99
100    /// Decodes Avro-encoded `bytes` into a `Row`.
101    pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
102        // Clear out any bytes that might be left over from
103        // an earlier run. This can happen if the
104        // `dsr.deserialize` call returns an error,
105        // causing us to return early.
106        let mut packer = self.row_buf.packer();
107        // The outer Result describes transient errors so use ? here to propagate
108        let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
109            Ok(ok) => ok,
110            Err(err) => return Ok(Err(err)),
111        };
112        *bytes = bytes2;
113        let dec = AvroFlatDecoder {
114            packer: &mut packer,
115            buf: &mut self.buf1,
116            is_top: true,
117        };
118        let dsr = GeneralDeserializer {
119            schema: resolved_schema.top_node(),
120        };
121        let result = dsr
122            .deserialize(bytes, dec)
123            .with_context(|| {
124                format!(
125                    "unable to decode row {}",
126                    match csr_schema_id {
127                        Some(id) => format!("(Avro schema id = {:?})", id),
128                        None => "".to_string(),
129                    }
130                )
131            })
132            .map(|_| self.row_buf.clone());
133        if result.is_ok() {
134            trace!(
135                "[customer-data] Decoded row {:?} in {}",
136                self.row_buf, self.debug_name
137            );
138        }
139        Ok(result)
140    }
141}
142
/// An `AvroDecode` implementation that packs decoded Avro values directly
/// into a `RowPacker`.
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch buffer for reading variable-length payloads (bytes, strings,
    /// JSON, UUIDs, decimals) out of the underlying reader.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder handles the top-level value. A top-level record's
    /// fields are pushed straight into the row; nested records are packed as
    /// a list.
    pub is_top: bool,
}
149
150impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
151    type Out = ();
152    #[inline]
153    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
154        self,
155        a: &mut A,
156    ) -> Result<Self::Out, AvroError> {
157        let mut str_buf = std::mem::take(self.buf);
158        let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
159            let mut expected = 0;
160            let mut stash = vec![];
161            // The idea here is that if the deserializer gives us fields in the order we're expecting,
162            // we can decode them directly into the row.
163            // If not, we need to decode them into a Value (the old, slow decoding path) and stash them,
164            // so that we can put everything in the right order at the end.
165            //
166            // TODO(btv) - this is pretty bad, as a misordering at the top of the schema graph will
167            // cause the _entire_ chunk under it to be decoded in the slow way!
168            // Maybe instead, we should decode to separate sub-Rows and then add an API
169            // to Row that just copies in the bytes from another one.
170            while let Some((_name, idx, f)) = a.next_field()? {
171                if idx == expected {
172                    expected += 1;
173                    f.decode_field(AvroFlatDecoder {
174                        packer: rp,
175                        buf: &mut str_buf,
176                        is_top: false,
177                    })?;
178                } else {
179                    let val = f.decode_field(ValueDecoder)?;
180                    stash.push((idx, val));
181                }
182            }
183            stash.sort_by_key(|(idx, _val)| *idx);
184            for (idx, val) in stash {
185                assert!(idx == expected);
186                expected += 1;
187                let dec = AvroFlatDecoder {
188                    packer: rp,
189                    buf: &mut str_buf,
190                    is_top: false,
191                };
192                give_value(dec, &val)?;
193            }
194            Ok(())
195        };
196        if self.is_top {
197            pack_record(self.packer)?;
198        } else {
199            self.packer.push_list_with(pack_record)?;
200        }
201        *self.buf = str_buf;
202        Ok(())
203    }
204    #[inline]
205    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
206        self,
207        idx: usize,
208        n_variants: usize,
209        null_variant: Option<usize>,
210        deserializer: D,
211        reader: &'b mut R,
212    ) -> Result<Self::Out, AvroError> {
213        if null_variant == Some(idx) {
214            for _ in 0..n_variants - 1 {
215                self.packer.push(Datum::Null)
216            }
217        } else {
218            let mut deserializer = Some(deserializer);
219            for i in 0..n_variants {
220                let dec = AvroFlatDecoder {
221                    packer: self.packer,
222                    buf: self.buf,
223                    is_top: false,
224                };
225                if null_variant != Some(i) {
226                    if i == idx {
227                        deserializer.take().unwrap().deserialize(reader, dec)?;
228                    } else {
229                        self.packer.push(Datum::Null)
230                    }
231                }
232            }
233        }
234        Ok(())
235    }
236
237    #[inline]
238    fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
239        self.packer.push(Datum::String(symbol));
240        Ok(())
241    }
242    #[inline]
243    fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
244        match scalar {
245            mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
246            mz_avro::types::Scalar::Boolean(val) => {
247                if val {
248                    self.packer.push(Datum::True)
249                } else {
250                    self.packer.push(Datum::False)
251                }
252            }
253            mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
254            mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
255            mz_avro::types::Scalar::Float(val) => {
256                self.packer.push(Datum::Float32(OrderedFloat(val)))
257            }
258            mz_avro::types::Scalar::Double(val) => {
259                self.packer.push(Datum::Float64(OrderedFloat(val)))
260            }
261            mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
262                Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
263            )),
264            mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
265                CheckedTimestamp::from_timestamplike(val)
266                    .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
267            )),
268        }
269        Ok(())
270    }
271
272    #[inline]
273    fn decimal<'b, R: AvroRead>(
274        self,
275        _precision: usize,
276        scale: usize,
277        r: ValueOrReader<'b, &'b [u8], R>,
278    ) -> Result<Self::Out, AvroError> {
279        let mut buf = match r {
280            ValueOrReader::Value(val) => val.to_vec(),
281            ValueOrReader::Reader { len, r } => {
282                self.buf.resize_with(len, Default::default);
283                r.read_exact(self.buf)?;
284                let v = self.buf.clone();
285                v
286            }
287        };
288
289        let scale = u8::try_from(scale).map_err(|_| {
290            DecodeError::Custom(format!(
291                "Error decoding decimal: scale must fit within u8, but got scale {}",
292                scale,
293            ))
294        })?;
295
296        let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
297            .map_err(|e| e.to_string_with_causes())
298            .map_err(DecodeError::Custom)?;
299
300        if n.is_special()
301            || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
302        {
303            return Err(AvroError::Decode(DecodeError::Custom(format!(
304                "Error decoding numeric: exceeds maximum precision {}",
305                numeric::NUMERIC_DATUM_MAX_PRECISION
306            ))));
307        }
308
309        self.packer.push(Datum::from(n));
310
311        Ok(())
312    }
313
314    #[inline]
315    fn bytes<'b, R: AvroRead>(
316        self,
317        r: ValueOrReader<'b, &'b [u8], R>,
318    ) -> Result<Self::Out, AvroError> {
319        let buf = match r {
320            ValueOrReader::Value(val) => val,
321            ValueOrReader::Reader { len, r } => {
322                self.buf.resize_with(len, Default::default);
323                r.read_exact(self.buf)?;
324                self.buf
325            }
326        };
327        self.packer.push(Datum::Bytes(buf));
328        Ok(())
329    }
330    #[inline]
331    fn string<'b, R: AvroRead>(
332        self,
333        r: ValueOrReader<'b, &'b str, R>,
334    ) -> Result<Self::Out, AvroError> {
335        let s = match r {
336            ValueOrReader::Value(val) => val,
337            ValueOrReader::Reader { len, r } => {
338                // TODO - this copy is unnecessary,
339                // we should special case to just look at the bytes
340                // directly when r is &[u8].
341                // It probably doesn't make a huge difference though.
342                self.buf.resize_with(len, Default::default);
343                r.read_exact(self.buf)?;
344                std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
345            }
346        };
347        self.packer.push(Datum::String(s));
348        Ok(())
349    }
350    #[inline]
351    fn json<'b, R: AvroRead>(
352        self,
353        r: ValueOrReader<'b, &'b serde_json::Value, R>,
354    ) -> Result<Self::Out, AvroError> {
355        match r {
356            ValueOrReader::Value(val) => {
357                JsonbPacker::new(self.packer)
358                    .pack_serde_json(val.clone())
359                    .map_err(|e| {
360                        // Technically, these are not the original bytes;
361                        // they've gone through a deserialize-serialize
362                        // round trip. Hopefully they will be close enough to still
363                        // be useful for debugging.
364                        let bytes = val.to_string().into_bytes();
365
366                        DecodeError::BadJson {
367                            category: e.classify(),
368                            bytes,
369                        }
370                    })?;
371            }
372            ValueOrReader::Reader { len, r } => {
373                self.buf.resize_with(len, Default::default);
374                r.read_exact(self.buf)?;
375                JsonbPacker::new(self.packer)
376                    .pack_slice(self.buf)
377                    .map_err(|e| DecodeError::BadJson {
378                        category: e.classify(),
379                        bytes: self.buf.to_owned(),
380                    })?;
381            }
382        }
383        Ok(())
384    }
385    #[inline]
386    fn uuid<'b, R: AvroRead>(
387        self,
388        r: ValueOrReader<'b, &'b [u8], R>,
389    ) -> Result<Self::Out, AvroError> {
390        let buf = match r {
391            ValueOrReader::Value(val) => val,
392            ValueOrReader::Reader { len, r } => {
393                self.buf.resize_with(len, Default::default);
394                r.read_exact(self.buf)?;
395                self.buf
396            }
397        };
398        let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
399        self.packer.push(Datum::Uuid(
400            Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
401        ));
402        Ok(())
403    }
404    #[inline]
405    fn fixed<'b, R: AvroRead>(
406        self,
407        r: ValueOrReader<'b, &'b [u8], R>,
408    ) -> Result<Self::Out, AvroError> {
409        self.bytes(r)
410    }
411    #[inline]
412    fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
413        self.is_top = false;
414        let mut str_buf = std::mem::take(self.buf);
415        self.packer.push_list_with(|rp| -> Result<(), AvroError> {
416            loop {
417                let next = AvroFlatDecoder {
418                    packer: rp,
419                    buf: &mut str_buf,
420                    is_top: false,
421                };
422                if a.decode_next(next)?.is_none() {
423                    break;
424                }
425            }
426            Ok(())
427        })?;
428        *self.buf = str_buf;
429        Ok(())
430    }
431    #[inline]
432    fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
433        // Map (key, value) pairs need to be unique and ordered.
434        let mut map = BTreeMap::new();
435        while let Some((name, f)) = a.next_entry()? {
436            map.insert(name, f.decode_field(ValueDecoder)?);
437        }
438        self.packer
439            .push_dict_with(|packer| -> Result<(), AvroError> {
440                for (key, val) in map {
441                    packer.push(Datum::String(key.as_str()));
442                    give_value(
443                        AvroFlatDecoder {
444                            packer,
445                            buf: &mut vec![],
446                            is_top: false,
447                        },
448                        &val,
449                    )?;
450                }
451                Ok(())
452            })?;
453
454        Ok(())
455    }
456}
457
/// A pair of optional `before` and `after` values.
///
/// NOTE(review): presumably models a change record (the state before and
/// after an update, with `None` on the missing side for inserts/deletes);
/// confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiffPair<T> {
    /// The `before` value, if present.
    pub before: Option<T>,
    /// The `after` value, if present.
    pub after: Option<T>,
}