mz_interchange/avro/
decode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::io::Read;
13use std::rc::Rc;
14
15use anyhow::{Context, Error};
16use mz_avro::error::{DecodeError, Error as AvroError};
17use mz_avro::{
18    AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
19    GeneralDeserializer, StatefulAvroDecodable, ValueDecoder, ValueOrReader, define_unexpected,
20    give_value,
21};
22use mz_ore::error::ErrorExt;
23use mz_repr::adt::date::Date;
24use mz_repr::adt::jsonb::JsonbPacker;
25use mz_repr::adt::numeric;
26use mz_repr::adt::timestamp::CheckedTimestamp;
27use mz_repr::{Datum, Row, RowPacker};
28use ordered_float::OrderedFloat;
29use tracing::trace;
30use uuid::Uuid;
31
32use crate::avro::ConfluentAvroResolver;
33
34/// Manages decoding of Avro-encoded bytes.
35#[derive(Debug)]
36pub struct Decoder {
37    csr_avro: ConfluentAvroResolver,
38    debug_name: String,
39    buf1: Vec<u8>,
40    row_buf: Row,
41}
42
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A failed decode must not leave state behind that corrupts the next,
    /// valid decode.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, None, "Test".to_string(), false).unwrap();

        // A single byte is too short to hold both ints, so decoding must fail.
        let mut truncated: &[u8] = &[0];
        let first_attempt = decoder.decode(&mut truncated).await.unwrap();
        assert_err!(first_attempt);

        // Two zero bytes encode `f1 = 0` and `f2 = 0`; this decode must succeed
        // and produce exactly those two datums.
        let mut well_formed: &[u8] = &[0, 0];
        let decoded = decoder.decode(&mut well_formed).await.unwrap().unwrap();
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded, expected);
    }
}
70
71impl Decoder {
72    /// Creates a new `Decoder`
73    ///
74    /// The provided schema is called the "reader schema", which is the schema
75    /// that we are expecting to use to decode records. The records may indicate
76    /// that they are encoded with a different schema; as long as those.
77    pub fn new(
78        reader_schema: &str,
79        ccsr_client: Option<mz_ccsr::Client>,
80        debug_name: String,
81        confluent_wire_format: bool,
82    ) -> anyhow::Result<Decoder> {
83        let csr_avro =
84            ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
85
86        Ok(Decoder {
87            csr_avro,
88            debug_name,
89            buf1: vec![],
90            row_buf: Row::default(),
91        })
92    }
93
94    /// Decodes Avro-encoded `bytes` into a `Row`.
95    pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
96        // Clear out any bytes that might be left over from
97        // an earlier run. This can happen if the
98        // `dsr.deserialize` call returns an error,
99        // causing us to return early.
100        let mut packer = self.row_buf.packer();
101        // The outer Result describes transient errors so use ? here to propagate
102        let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
103            Ok(ok) => ok,
104            Err(err) => return Ok(Err(err)),
105        };
106        *bytes = bytes2;
107        let dec = AvroFlatDecoder {
108            packer: &mut packer,
109            buf: &mut self.buf1,
110            is_top: true,
111        };
112        let dsr = GeneralDeserializer {
113            schema: resolved_schema.top_node(),
114        };
115        let result = dsr
116            .deserialize(bytes, dec)
117            .with_context(|| {
118                format!(
119                    "unable to decode row {}",
120                    match csr_schema_id {
121                        Some(id) => format!("(Avro schema id = {:?})", id),
122                        None => "".to_string(),
123                    }
124                )
125            })
126            .map(|_| self.row_buf.clone());
127        if result.is_ok() {
128            trace!(
129                "[customer-data] Decoded row {:?} in {}",
130                self.row_buf, self.debug_name
131            );
132        }
133        Ok(result)
134    }
135}
136
/// Decodes an Avro string by copying its UTF-8 bytes into a caller-owned buffer.
pub struct AvroStringDecoder<'buf> {
    /// Destination buffer; its previous contents are replaced on decode.
    pub buf: &'buf mut Vec<u8>,
}
140
141impl<'a> AvroDecode for AvroStringDecoder<'a> {
142    type Out = ();
143    fn string<'b, R: AvroRead>(
144        self,
145        r: ValueOrReader<'b, &'b str, R>,
146    ) -> Result<Self::Out, AvroError> {
147        match r {
148            ValueOrReader::Value(val) => {
149                self.buf.resize_with(val.len(), Default::default);
150                val.as_bytes().read_exact(self.buf)?;
151            }
152            ValueOrReader::Reader { len, r } => {
153                self.buf.resize_with(len, Default::default);
154                r.read_exact(self.buf)?;
155            }
156        }
157        Ok(())
158    }
159    define_unexpected! {
160        record, union_branch, array, map, enum_variant, scalar, decimal, bytes, json, uuid, fixed
161    }
162}
163
164// TODO(parkmycar): Should we just delete this?
165#[allow(dead_code)]
166pub(super) struct OptionalRecordDecoder<'a, 'row> {
167    pub packer: &'a mut RowPacker<'row>,
168    pub buf: &'a mut Vec<u8>,
169}
170
171impl<'a, 'row> AvroDecode for OptionalRecordDecoder<'a, 'row> {
172    type Out = bool;
173    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
174        self,
175        idx: usize,
176        _n_variants: usize,
177        null_variant: Option<usize>,
178        deserializer: D,
179        reader: &'b mut R,
180    ) -> Result<Self::Out, AvroError> {
181        if Some(idx) == null_variant {
182            // we are done, the row is null!
183            Ok(false)
184        } else {
185            let d = AvroFlatDecoder {
186                packer: self.packer,
187                buf: self.buf,
188                is_top: false,
189            };
190            deserializer.deserialize(reader, d)?;
191            Ok(true)
192        }
193    }
194    define_unexpected! {
195        record, array, map, enum_variant, scalar, decimal, bytes, string, json, uuid, fixed
196    }
197}
198
199pub(super) struct RowDecoder {
200    state: (Rc<RefCell<Row>>, Rc<RefCell<Vec<u8>>>),
201}
202
203impl AvroDecode for RowDecoder {
204    type Out = RowWrapper;
205    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
206        self,
207        a: &mut A,
208    ) -> Result<Self::Out, AvroError> {
209        let mut row_borrow = self.state.0.borrow_mut();
210        let mut buf_borrow = self.state.1.borrow_mut();
211        let mut packer = row_borrow.packer();
212        let inner = AvroFlatDecoder {
213            packer: &mut packer,
214            buf: &mut buf_borrow,
215            is_top: true,
216        };
217        inner.record(a)?;
218        Ok(RowWrapper(row_borrow.clone()))
219    }
220    define_unexpected! {
221        union_branch, array, map, enum_variant, scalar, decimal, bytes, string, json, uuid, fixed
222    }
223}
224
225// Get around orphan rule
226#[derive(Debug)]
227pub(super) struct RowWrapper(#[allow(dead_code)] pub Row);
228
229impl StatefulAvroDecodable for RowWrapper {
230    type Decoder = RowDecoder;
231    // TODO - can we make this some sort of &'a mut (Row, Vec<u8>) without
232    // running into lifetime crap?
233    type State = (Rc<RefCell<Row>>, Rc<RefCell<Vec<u8>>>);
234
235    fn new_decoder(state: Self::State) -> Self::Decoder {
236        Self::Decoder { state }
237    }
238}
239
240#[derive(Debug)]
241pub struct AvroFlatDecoder<'a, 'row> {
242    pub packer: &'a mut RowPacker<'row>,
243    pub buf: &'a mut Vec<u8>,
244    pub is_top: bool,
245}
246
247impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
248    type Out = ();
249    #[inline]
250    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
251        self,
252        a: &mut A,
253    ) -> Result<Self::Out, AvroError> {
254        let mut str_buf = std::mem::take(self.buf);
255        let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
256            let mut expected = 0;
257            let mut stash = vec![];
258            // The idea here is that if the deserializer gives us fields in the order we're expecting,
259            // we can decode them directly into the row.
260            // If not, we need to decode them into a Value (the old, slow decoding path) and stash them,
261            // so that we can put everything in the right order at the end.
262            //
263            // TODO(btv) - this is pretty bad, as a misordering at the top of the schema graph will
264            // cause the _entire_ chunk under it to be decoded in the slow way!
265            // Maybe instead, we should decode to separate sub-Rows and then add an API
266            // to Row that just copies in the bytes from another one.
267            while let Some((_name, idx, f)) = a.next_field()? {
268                if idx == expected {
269                    expected += 1;
270                    f.decode_field(AvroFlatDecoder {
271                        packer: rp,
272                        buf: &mut str_buf,
273                        is_top: false,
274                    })?;
275                } else {
276                    let val = f.decode_field(ValueDecoder)?;
277                    stash.push((idx, val));
278                }
279            }
280            stash.sort_by_key(|(idx, _val)| *idx);
281            for (idx, val) in stash {
282                assert!(idx == expected);
283                expected += 1;
284                let dec = AvroFlatDecoder {
285                    packer: rp,
286                    buf: &mut str_buf,
287                    is_top: false,
288                };
289                give_value(dec, &val)?;
290            }
291            Ok(())
292        };
293        if self.is_top {
294            pack_record(self.packer)?;
295        } else {
296            self.packer.push_list_with(pack_record)?;
297        }
298        *self.buf = str_buf;
299        Ok(())
300    }
301    #[inline]
302    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
303        self,
304        idx: usize,
305        n_variants: usize,
306        null_variant: Option<usize>,
307        deserializer: D,
308        reader: &'b mut R,
309    ) -> Result<Self::Out, AvroError> {
310        if null_variant == Some(idx) {
311            for _ in 0..n_variants - 1 {
312                self.packer.push(Datum::Null)
313            }
314        } else {
315            let mut deserializer = Some(deserializer);
316            for i in 0..n_variants {
317                let dec = AvroFlatDecoder {
318                    packer: self.packer,
319                    buf: self.buf,
320                    is_top: false,
321                };
322                if null_variant != Some(i) {
323                    if i == idx {
324                        deserializer.take().unwrap().deserialize(reader, dec)?;
325                    } else {
326                        self.packer.push(Datum::Null)
327                    }
328                }
329            }
330        }
331        Ok(())
332    }
333
334    #[inline]
335    fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
336        self.packer.push(Datum::String(symbol));
337        Ok(())
338    }
339    #[inline]
340    fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
341        match scalar {
342            mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
343            mz_avro::types::Scalar::Boolean(val) => {
344                if val {
345                    self.packer.push(Datum::True)
346                } else {
347                    self.packer.push(Datum::False)
348                }
349            }
350            mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
351            mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
352            mz_avro::types::Scalar::Float(val) => {
353                self.packer.push(Datum::Float32(OrderedFloat(val)))
354            }
355            mz_avro::types::Scalar::Double(val) => {
356                self.packer.push(Datum::Float64(OrderedFloat(val)))
357            }
358            mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
359                Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
360            )),
361            mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
362                CheckedTimestamp::from_timestamplike(val)
363                    .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
364            )),
365        }
366        Ok(())
367    }
368
369    #[inline]
370    fn decimal<'b, R: AvroRead>(
371        self,
372        _precision: usize,
373        scale: usize,
374        r: ValueOrReader<'b, &'b [u8], R>,
375    ) -> Result<Self::Out, AvroError> {
376        let mut buf = match r {
377            ValueOrReader::Value(val) => val.to_vec(),
378            ValueOrReader::Reader { len, r } => {
379                self.buf.resize_with(len, Default::default);
380                r.read_exact(self.buf)?;
381                let v = self.buf.clone();
382                v
383            }
384        };
385
386        let scale = u8::try_from(scale).map_err(|_| {
387            DecodeError::Custom(format!(
388                "Error decoding decimal: scale must fit within u8, but got scale {}",
389                scale,
390            ))
391        })?;
392
393        let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
394            .map_err(|e| e.to_string_with_causes())
395            .map_err(DecodeError::Custom)?;
396
397        if n.is_special()
398            || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
399        {
400            return Err(AvroError::Decode(DecodeError::Custom(format!(
401                "Error decoding numeric: exceeds maximum precision {}",
402                numeric::NUMERIC_DATUM_MAX_PRECISION
403            ))));
404        }
405
406        self.packer.push(Datum::from(n));
407
408        Ok(())
409    }
410
411    #[inline]
412    fn bytes<'b, R: AvroRead>(
413        self,
414        r: ValueOrReader<'b, &'b [u8], R>,
415    ) -> Result<Self::Out, AvroError> {
416        let buf = match r {
417            ValueOrReader::Value(val) => val,
418            ValueOrReader::Reader { len, r } => {
419                self.buf.resize_with(len, Default::default);
420                r.read_exact(self.buf)?;
421                self.buf
422            }
423        };
424        self.packer.push(Datum::Bytes(buf));
425        Ok(())
426    }
427    #[inline]
428    fn string<'b, R: AvroRead>(
429        self,
430        r: ValueOrReader<'b, &'b str, R>,
431    ) -> Result<Self::Out, AvroError> {
432        let s = match r {
433            ValueOrReader::Value(val) => val,
434            ValueOrReader::Reader { len, r } => {
435                // TODO - this copy is unnecessary,
436                // we should special case to just look at the bytes
437                // directly when r is &[u8].
438                // It probably doesn't make a huge difference though.
439                self.buf.resize_with(len, Default::default);
440                r.read_exact(self.buf)?;
441                std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
442            }
443        };
444        self.packer.push(Datum::String(s));
445        Ok(())
446    }
447    #[inline]
448    fn json<'b, R: AvroRead>(
449        self,
450        r: ValueOrReader<'b, &'b serde_json::Value, R>,
451    ) -> Result<Self::Out, AvroError> {
452        match r {
453            ValueOrReader::Value(val) => {
454                JsonbPacker::new(self.packer)
455                    .pack_serde_json(val.clone())
456                    .map_err(|e| {
457                        // Technically, these are not the original bytes;
458                        // they've gone through a deserialize-serialize
459                        // round trip. Hopefully they will be close enough to still
460                        // be useful for debugging.
461                        let bytes = val.to_string().into_bytes();
462
463                        DecodeError::BadJson {
464                            category: e.classify(),
465                            bytes,
466                        }
467                    })?;
468            }
469            ValueOrReader::Reader { len, r } => {
470                self.buf.resize_with(len, Default::default);
471                r.read_exact(self.buf)?;
472                JsonbPacker::new(self.packer)
473                    .pack_slice(self.buf)
474                    .map_err(|e| DecodeError::BadJson {
475                        category: e.classify(),
476                        bytes: self.buf.to_owned(),
477                    })?;
478            }
479        }
480        Ok(())
481    }
482    #[inline]
483    fn uuid<'b, R: AvroRead>(
484        self,
485        r: ValueOrReader<'b, &'b [u8], R>,
486    ) -> Result<Self::Out, AvroError> {
487        let buf = match r {
488            ValueOrReader::Value(val) => val,
489            ValueOrReader::Reader { len, r } => {
490                self.buf.resize_with(len, Default::default);
491                r.read_exact(self.buf)?;
492                self.buf
493            }
494        };
495        let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
496        self.packer.push(Datum::Uuid(
497            Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
498        ));
499        Ok(())
500    }
501    #[inline]
502    fn fixed<'b, R: AvroRead>(
503        self,
504        r: ValueOrReader<'b, &'b [u8], R>,
505    ) -> Result<Self::Out, AvroError> {
506        self.bytes(r)
507    }
508    #[inline]
509    fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
510        self.is_top = false;
511        let mut str_buf = std::mem::take(self.buf);
512        self.packer.push_list_with(|rp| -> Result<(), AvroError> {
513            loop {
514                let next = AvroFlatDecoder {
515                    packer: rp,
516                    buf: &mut str_buf,
517                    is_top: false,
518                };
519                if a.decode_next(next)?.is_none() {
520                    break;
521                }
522            }
523            Ok(())
524        })?;
525        *self.buf = str_buf;
526        Ok(())
527    }
528    #[inline]
529    fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
530        // Map (key, value) pairs need to be unique and ordered.
531        let mut map = BTreeMap::new();
532        while let Some((name, f)) = a.next_entry()? {
533            map.insert(name, f.decode_field(ValueDecoder)?);
534        }
535        self.packer
536            .push_dict_with(|packer| -> Result<(), AvroError> {
537                for (key, val) in map {
538                    packer.push(Datum::String(key.as_str()));
539                    give_value(
540                        AvroFlatDecoder {
541                            packer,
542                            buf: &mut vec![],
543                            is_top: false,
544                        },
545                        &val,
546                    )?;
547                }
548                Ok(())
549            })?;
550
551        Ok(())
552    }
553}
554
/// A pair of optional values describing a change: the state before it and the
/// state after it.
#[derive(Clone, Debug)]
pub struct DiffPair<T> {
    /// Value prior to the change, if any.
    pub before: Option<T>,
    /// Value following the change, if any.
    pub after: Option<T>,
}