mz_interchange/avro/
decode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15    AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16    GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::ConfluentAvroResolver;
30
/// Manages decoding of Avro-encoded bytes into `Row`s.
#[derive(Debug)]
pub struct Decoder {
    /// Resolves the schema a record was written with against the reader
    /// schema supplied to [`Decoder::new`].
    csr_avro: ConfluentAvroResolver,
    /// Name identifying this decoder in trace logs.
    debug_name: String,
    /// Scratch buffer lent to `AvroFlatDecoder` for reading variable-length
    /// values; kept on the struct so its allocation is reused across calls.
    buf1: Vec<u8>,
    /// Row each record is packed into; a successful decode returns a clone.
    row_buf: Row,
}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A failed decode must not leave partial data behind that corrupts the
    /// result of the next, well-formed record.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, None, "Test".to_string(), false).unwrap();

        // Only one of the two ints is present, so this blob is invalid for
        // the schema and decoding has to fail.
        let mut truncated: &[u8] = &[0];
        let first = decoder.decode(&mut truncated).await.unwrap();
        assert_err!(first);

        // Two zero bytes encode both ints as zero; this decode must succeed
        // with exactly that value despite the preceding failure.
        let mut complete: &[u8] = &[0, 0];
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        let decoded = decoder.decode(&mut complete).await.unwrap().unwrap();
        assert_eq!(decoded, expected);
    }
}
67
68impl Decoder {
69    /// Creates a new `Decoder`
70    ///
71    /// The provided schema is called the "reader schema", which is the schema
72    /// that we are expecting to use to decode records. The records may indicate
73    /// that they are encoded with a different schema; as long as those.
74    pub fn new(
75        reader_schema: &str,
76        ccsr_client: Option<mz_ccsr::Client>,
77        debug_name: String,
78        confluent_wire_format: bool,
79    ) -> anyhow::Result<Decoder> {
80        let csr_avro =
81            ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
82
83        Ok(Decoder {
84            csr_avro,
85            debug_name,
86            buf1: vec![],
87            row_buf: Row::default(),
88        })
89    }
90
91    /// Decodes Avro-encoded `bytes` into a `Row`.
92    pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
93        // Clear out any bytes that might be left over from
94        // an earlier run. This can happen if the
95        // `dsr.deserialize` call returns an error,
96        // causing us to return early.
97        let mut packer = self.row_buf.packer();
98        // The outer Result describes transient errors so use ? here to propagate
99        let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
100            Ok(ok) => ok,
101            Err(err) => return Ok(Err(err)),
102        };
103        *bytes = bytes2;
104        let dec = AvroFlatDecoder {
105            packer: &mut packer,
106            buf: &mut self.buf1,
107            is_top: true,
108        };
109        let dsr = GeneralDeserializer {
110            schema: resolved_schema.top_node(),
111        };
112        let result = dsr
113            .deserialize(bytes, dec)
114            .with_context(|| {
115                format!(
116                    "unable to decode row {}",
117                    match csr_schema_id {
118                        Some(id) => format!("(Avro schema id = {:?})", id),
119                        None => "".to_string(),
120                    }
121                )
122            })
123            .map(|_| self.row_buf.clone());
124        if result.is_ok() {
125            trace!(
126                "[customer-data] Decoded row {:?} in {}",
127                self.row_buf, self.debug_name
128            );
129        }
130        Ok(result)
131    }
132}
133
/// An `AvroDecode` implementation that packs decoded Avro values directly
/// into a `RowPacker`, so the fields of the top-level record become the
/// row's datums.
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch buffer used when a bytes, string, JSON, UUID or decimal value
    /// has to be read from the underlying reader rather than being handed
    /// over as an already-materialized value.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder is positioned at the top-level record. A
    /// top-level record's fields are pushed directly into `packer`; a nested
    /// record is pushed as a single list datum.
    pub is_top: bool,
}
140
141impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
142    type Out = ();
143    #[inline]
144    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
145        self,
146        a: &mut A,
147    ) -> Result<Self::Out, AvroError> {
148        let mut str_buf = std::mem::take(self.buf);
149        let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
150            let mut expected = 0;
151            let mut stash = vec![];
152            // The idea here is that if the deserializer gives us fields in the order we're expecting,
153            // we can decode them directly into the row.
154            // If not, we need to decode them into a Value (the old, slow decoding path) and stash them,
155            // so that we can put everything in the right order at the end.
156            //
157            // TODO(btv) - this is pretty bad, as a misordering at the top of the schema graph will
158            // cause the _entire_ chunk under it to be decoded in the slow way!
159            // Maybe instead, we should decode to separate sub-Rows and then add an API
160            // to Row that just copies in the bytes from another one.
161            while let Some((_name, idx, f)) = a.next_field()? {
162                if idx == expected {
163                    expected += 1;
164                    f.decode_field(AvroFlatDecoder {
165                        packer: rp,
166                        buf: &mut str_buf,
167                        is_top: false,
168                    })?;
169                } else {
170                    let val = f.decode_field(ValueDecoder)?;
171                    stash.push((idx, val));
172                }
173            }
174            stash.sort_by_key(|(idx, _val)| *idx);
175            for (idx, val) in stash {
176                assert!(idx == expected);
177                expected += 1;
178                let dec = AvroFlatDecoder {
179                    packer: rp,
180                    buf: &mut str_buf,
181                    is_top: false,
182                };
183                give_value(dec, &val)?;
184            }
185            Ok(())
186        };
187        if self.is_top {
188            pack_record(self.packer)?;
189        } else {
190            self.packer.push_list_with(pack_record)?;
191        }
192        *self.buf = str_buf;
193        Ok(())
194    }
195    #[inline]
196    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
197        self,
198        idx: usize,
199        n_variants: usize,
200        null_variant: Option<usize>,
201        deserializer: D,
202        reader: &'b mut R,
203    ) -> Result<Self::Out, AvroError> {
204        if null_variant == Some(idx) {
205            for _ in 0..n_variants - 1 {
206                self.packer.push(Datum::Null)
207            }
208        } else {
209            let mut deserializer = Some(deserializer);
210            for i in 0..n_variants {
211                let dec = AvroFlatDecoder {
212                    packer: self.packer,
213                    buf: self.buf,
214                    is_top: false,
215                };
216                if null_variant != Some(i) {
217                    if i == idx {
218                        deserializer.take().unwrap().deserialize(reader, dec)?;
219                    } else {
220                        self.packer.push(Datum::Null)
221                    }
222                }
223            }
224        }
225        Ok(())
226    }
227
228    #[inline]
229    fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
230        self.packer.push(Datum::String(symbol));
231        Ok(())
232    }
233    #[inline]
234    fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
235        match scalar {
236            mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
237            mz_avro::types::Scalar::Boolean(val) => {
238                if val {
239                    self.packer.push(Datum::True)
240                } else {
241                    self.packer.push(Datum::False)
242                }
243            }
244            mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
245            mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
246            mz_avro::types::Scalar::Float(val) => {
247                self.packer.push(Datum::Float32(OrderedFloat(val)))
248            }
249            mz_avro::types::Scalar::Double(val) => {
250                self.packer.push(Datum::Float64(OrderedFloat(val)))
251            }
252            mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
253                Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
254            )),
255            mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
256                CheckedTimestamp::from_timestamplike(val)
257                    .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
258            )),
259        }
260        Ok(())
261    }
262
263    #[inline]
264    fn decimal<'b, R: AvroRead>(
265        self,
266        _precision: usize,
267        scale: usize,
268        r: ValueOrReader<'b, &'b [u8], R>,
269    ) -> Result<Self::Out, AvroError> {
270        let mut buf = match r {
271            ValueOrReader::Value(val) => val.to_vec(),
272            ValueOrReader::Reader { len, r } => {
273                self.buf.resize_with(len, Default::default);
274                r.read_exact(self.buf)?;
275                let v = self.buf.clone();
276                v
277            }
278        };
279
280        let scale = u8::try_from(scale).map_err(|_| {
281            DecodeError::Custom(format!(
282                "Error decoding decimal: scale must fit within u8, but got scale {}",
283                scale,
284            ))
285        })?;
286
287        let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
288            .map_err(|e| e.to_string_with_causes())
289            .map_err(DecodeError::Custom)?;
290
291        if n.is_special()
292            || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
293        {
294            return Err(AvroError::Decode(DecodeError::Custom(format!(
295                "Error decoding numeric: exceeds maximum precision {}",
296                numeric::NUMERIC_DATUM_MAX_PRECISION
297            ))));
298        }
299
300        self.packer.push(Datum::from(n));
301
302        Ok(())
303    }
304
305    #[inline]
306    fn bytes<'b, R: AvroRead>(
307        self,
308        r: ValueOrReader<'b, &'b [u8], R>,
309    ) -> Result<Self::Out, AvroError> {
310        let buf = match r {
311            ValueOrReader::Value(val) => val,
312            ValueOrReader::Reader { len, r } => {
313                self.buf.resize_with(len, Default::default);
314                r.read_exact(self.buf)?;
315                self.buf
316            }
317        };
318        self.packer.push(Datum::Bytes(buf));
319        Ok(())
320    }
321    #[inline]
322    fn string<'b, R: AvroRead>(
323        self,
324        r: ValueOrReader<'b, &'b str, R>,
325    ) -> Result<Self::Out, AvroError> {
326        let s = match r {
327            ValueOrReader::Value(val) => val,
328            ValueOrReader::Reader { len, r } => {
329                // TODO - this copy is unnecessary,
330                // we should special case to just look at the bytes
331                // directly when r is &[u8].
332                // It probably doesn't make a huge difference though.
333                self.buf.resize_with(len, Default::default);
334                r.read_exact(self.buf)?;
335                std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
336            }
337        };
338        self.packer.push(Datum::String(s));
339        Ok(())
340    }
341    #[inline]
342    fn json<'b, R: AvroRead>(
343        self,
344        r: ValueOrReader<'b, &'b serde_json::Value, R>,
345    ) -> Result<Self::Out, AvroError> {
346        match r {
347            ValueOrReader::Value(val) => {
348                JsonbPacker::new(self.packer)
349                    .pack_serde_json(val.clone())
350                    .map_err(|e| {
351                        // Technically, these are not the original bytes;
352                        // they've gone through a deserialize-serialize
353                        // round trip. Hopefully they will be close enough to still
354                        // be useful for debugging.
355                        let bytes = val.to_string().into_bytes();
356
357                        DecodeError::BadJson {
358                            category: e.classify(),
359                            bytes,
360                        }
361                    })?;
362            }
363            ValueOrReader::Reader { len, r } => {
364                self.buf.resize_with(len, Default::default);
365                r.read_exact(self.buf)?;
366                JsonbPacker::new(self.packer)
367                    .pack_slice(self.buf)
368                    .map_err(|e| DecodeError::BadJson {
369                        category: e.classify(),
370                        bytes: self.buf.to_owned(),
371                    })?;
372            }
373        }
374        Ok(())
375    }
376    #[inline]
377    fn uuid<'b, R: AvroRead>(
378        self,
379        r: ValueOrReader<'b, &'b [u8], R>,
380    ) -> Result<Self::Out, AvroError> {
381        let buf = match r {
382            ValueOrReader::Value(val) => val,
383            ValueOrReader::Reader { len, r } => {
384                self.buf.resize_with(len, Default::default);
385                r.read_exact(self.buf)?;
386                self.buf
387            }
388        };
389        let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
390        self.packer.push(Datum::Uuid(
391            Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
392        ));
393        Ok(())
394    }
395    #[inline]
396    fn fixed<'b, R: AvroRead>(
397        self,
398        r: ValueOrReader<'b, &'b [u8], R>,
399    ) -> Result<Self::Out, AvroError> {
400        self.bytes(r)
401    }
402    #[inline]
403    fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
404        self.is_top = false;
405        let mut str_buf = std::mem::take(self.buf);
406        self.packer.push_list_with(|rp| -> Result<(), AvroError> {
407            loop {
408                let next = AvroFlatDecoder {
409                    packer: rp,
410                    buf: &mut str_buf,
411                    is_top: false,
412                };
413                if a.decode_next(next)?.is_none() {
414                    break;
415                }
416            }
417            Ok(())
418        })?;
419        *self.buf = str_buf;
420        Ok(())
421    }
422    #[inline]
423    fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
424        // Map (key, value) pairs need to be unique and ordered.
425        let mut map = BTreeMap::new();
426        while let Some((name, f)) = a.next_entry()? {
427            map.insert(name, f.decode_field(ValueDecoder)?);
428        }
429        self.packer
430            .push_dict_with(|packer| -> Result<(), AvroError> {
431                for (key, val) in map {
432                    packer.push(Datum::String(key.as_str()));
433                    give_value(
434                        AvroFlatDecoder {
435                            packer,
436                            buf: &mut vec![],
437                            is_top: false,
438                        },
439                        &val,
440                    )?;
441                }
442                Ok(())
443            })?;
444
445        Ok(())
446    }
447}
448
/// A pair of optional `before` and `after` values.
///
/// NOTE(review): presumably models a change event (the state before and after
/// an update, with `None` for an absent side); confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiffPair<T> {
    /// The value before the change, if any.
    pub before: Option<T>,
    /// The value after the change, if any.
    pub after: Option<T>,
}