1use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::io::Read;
13use std::rc::Rc;
14
15use anyhow::{Context, Error};
16use mz_avro::error::{DecodeError, Error as AvroError};
17use mz_avro::{
18 AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
19 GeneralDeserializer, StatefulAvroDecodable, ValueDecoder, ValueOrReader, define_unexpected,
20 give_value,
21};
22use mz_ore::error::ErrorExt;
23use mz_repr::adt::date::Date;
24use mz_repr::adt::jsonb::JsonbPacker;
25use mz_repr::adt::numeric;
26use mz_repr::adt::timestamp::CheckedTimestamp;
27use mz_repr::{Datum, Row, RowPacker};
28use ordered_float::OrderedFloat;
29use tracing::trace;
30use uuid::Uuid;
31
32use crate::avro::ConfluentAvroResolver;
33
/// Stateful decoder that turns Avro-encoded byte slices into [`Row`]s.
///
/// The scratch buffers are owned by the decoder so that they can be reused
/// from one `decode` call to the next.
#[derive(Debug)]
pub struct Decoder {
    /// Resolves the schema to decode each message with (see `decode`).
    csr_avro: ConfluentAvroResolver,
    /// Label included in trace output when a row is decoded successfully.
    debug_name: String,
    /// Scratch byte buffer lent to the `AvroFlatDecoder` during `decode`.
    buf1: Vec<u8>,
    /// Row that each `decode` call packs into; a clone of it is returned on success.
    row_buf: Row,
}
42
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A decode that fails must leave the decoder usable: the next decode of
    /// well-formed input has to succeed and yield exactly the expected row.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let reader_schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder =
            Decoder::new(reader_schema, None, String::from("Test"), false).unwrap();

        // First attempt: the outer result is fine, the inner one must be an error.
        let mut failing_input: &[u8] = &[0];
        let first_attempt = decoder.decode(&mut failing_input).await.unwrap();
        assert_err!(first_attempt);

        // Second attempt on the same decoder: must not panic and must produce
        // a row holding two zero-valued ints.
        let mut valid_input: &[u8] = &[0, 0];
        let decoded = decoder.decode(&mut valid_input).await.unwrap().unwrap();
        assert_eq!(decoded, Row::pack([Datum::Int32(0), Datum::Int32(0)]));
    }
}
70
71impl Decoder {
72 pub fn new(
78 reader_schema: &str,
79 ccsr_client: Option<mz_ccsr::Client>,
80 debug_name: String,
81 confluent_wire_format: bool,
82 ) -> anyhow::Result<Decoder> {
83 let csr_avro =
84 ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
85
86 Ok(Decoder {
87 csr_avro,
88 debug_name,
89 buf1: vec![],
90 row_buf: Row::default(),
91 })
92 }
93
94 pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
96 let mut packer = self.row_buf.packer();
101 let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
103 Ok(ok) => ok,
104 Err(err) => return Ok(Err(err)),
105 };
106 *bytes = bytes2;
107 let dec = AvroFlatDecoder {
108 packer: &mut packer,
109 buf: &mut self.buf1,
110 is_top: true,
111 };
112 let dsr = GeneralDeserializer {
113 schema: resolved_schema.top_node(),
114 };
115 let result = dsr
116 .deserialize(bytes, dec)
117 .with_context(|| {
118 format!(
119 "unable to decode row {}",
120 match csr_schema_id {
121 Some(id) => format!("(Avro schema id = {:?})", id),
122 None => "".to_string(),
123 }
124 )
125 })
126 .map(|_| self.row_buf.clone());
127 if result.is_ok() {
128 trace!(
129 "[customer-data] Decoded row {:?} in {}",
130 self.row_buf, self.debug_name
131 );
132 }
133 Ok(result)
134 }
135}
136
/// Decoder that accepts only Avro strings and copies their raw bytes into a
/// caller-provided buffer.
pub struct AvroStringDecoder<'a> {
    /// Destination buffer; resized to the string's byte length and overwritten.
    pub buf: &'a mut Vec<u8>,
}
140
141impl<'a> AvroDecode for AvroStringDecoder<'a> {
142 type Out = ();
143 fn string<'b, R: AvroRead>(
144 self,
145 r: ValueOrReader<'b, &'b str, R>,
146 ) -> Result<Self::Out, AvroError> {
147 match r {
148 ValueOrReader::Value(val) => {
149 self.buf.resize_with(val.len(), Default::default);
150 val.as_bytes().read_exact(self.buf)?;
151 }
152 ValueOrReader::Reader { len, r } => {
153 self.buf.resize_with(len, Default::default);
154 r.read_exact(self.buf)?;
155 }
156 }
157 Ok(())
158 }
159 define_unexpected! {
160 record, union_branch, array, map, enum_variant, scalar, decimal, bytes, json, uuid, fixed
161 }
162}
163
/// Decoder for a union that is either null or some other value: reports
/// whether a non-null value was present and, if so, packs it into `packer`.
// NOTE(review): `dead_code` is allowed presumably because nothing in this
// module constructs the type — confirm it is still needed.
#[allow(dead_code)]
pub(super) struct OptionalRecordDecoder<'a, 'row> {
    /// Packer that receives the decoded value when the union is non-null.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch byte buffer handed on to the nested `AvroFlatDecoder`.
    pub buf: &'a mut Vec<u8>,
}
170
171impl<'a, 'row> AvroDecode for OptionalRecordDecoder<'a, 'row> {
172 type Out = bool;
173 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
174 self,
175 idx: usize,
176 _n_variants: usize,
177 null_variant: Option<usize>,
178 deserializer: D,
179 reader: &'b mut R,
180 ) -> Result<Self::Out, AvroError> {
181 if Some(idx) == null_variant {
182 Ok(false)
184 } else {
185 let d = AvroFlatDecoder {
186 packer: self.packer,
187 buf: self.buf,
188 is_top: false,
189 };
190 deserializer.deserialize(reader, d)?;
191 Ok(true)
192 }
193 }
194 define_unexpected! {
195 record, array, map, enum_variant, scalar, decimal, bytes, string, json, uuid, fixed
196 }
197}
198
/// Decoder that packs an Avro record into a shared [`Row`] and returns a copy
/// of it wrapped in [`RowWrapper`].
pub(super) struct RowDecoder {
    /// Shared state: `.0` is the row that records are packed into, `.1` is the
    /// scratch byte buffer lent to the inner `AvroFlatDecoder`.
    state: (Rc<RefCell<Row>>, Rc<RefCell<Vec<u8>>>),
}
202
203impl AvroDecode for RowDecoder {
204 type Out = RowWrapper;
205 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
206 self,
207 a: &mut A,
208 ) -> Result<Self::Out, AvroError> {
209 let mut row_borrow = self.state.0.borrow_mut();
210 let mut buf_borrow = self.state.1.borrow_mut();
211 let mut packer = row_borrow.packer();
212 let inner = AvroFlatDecoder {
213 packer: &mut packer,
214 buf: &mut buf_borrow,
215 is_top: true,
216 };
217 inner.record(a)?;
218 Ok(RowWrapper(row_borrow.clone()))
219 }
220 define_unexpected! {
221 union_branch, array, map, enum_variant, scalar, decimal, bytes, string, json, uuid, fixed
222 }
223}
224
/// Newtype around [`Row`]; it is the output type of [`RowDecoder`] and the
/// type that implements `StatefulAvroDecodable`.
// NOTE(review): the inner `dead_code` allowance presumably exists because the
// field is never read directly in this crate — confirm it is still needed.
#[derive(Debug)]
pub(super) struct RowWrapper(#[allow(dead_code)] pub Row);
228
229impl StatefulAvroDecodable for RowWrapper {
230 type Decoder = RowDecoder;
231 type State = (Rc<RefCell<Row>>, Rc<RefCell<Vec<u8>>>);
234
235 fn new_decoder(state: Self::State) -> Self::Decoder {
236 Self::Decoder { state }
237 }
238}
239
/// Decoder that packs Avro values into a [`RowPacker`] as `Datum`s.
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch byte buffer used when a value has to be read from a reader.
    pub buf: &'a mut Vec<u8>,
    /// When `true`, a record's fields are packed directly into `packer`;
    /// otherwise the record is packed through `push_list_with`.
    pub is_top: bool,
}
246
247impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
248 type Out = ();
249 #[inline]
250 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
251 self,
252 a: &mut A,
253 ) -> Result<Self::Out, AvroError> {
254 let mut str_buf = std::mem::take(self.buf);
255 let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
256 let mut expected = 0;
257 let mut stash = vec![];
258 while let Some((_name, idx, f)) = a.next_field()? {
268 if idx == expected {
269 expected += 1;
270 f.decode_field(AvroFlatDecoder {
271 packer: rp,
272 buf: &mut str_buf,
273 is_top: false,
274 })?;
275 } else {
276 let val = f.decode_field(ValueDecoder)?;
277 stash.push((idx, val));
278 }
279 }
280 stash.sort_by_key(|(idx, _val)| *idx);
281 for (idx, val) in stash {
282 assert!(idx == expected);
283 expected += 1;
284 let dec = AvroFlatDecoder {
285 packer: rp,
286 buf: &mut str_buf,
287 is_top: false,
288 };
289 give_value(dec, &val)?;
290 }
291 Ok(())
292 };
293 if self.is_top {
294 pack_record(self.packer)?;
295 } else {
296 self.packer.push_list_with(pack_record)?;
297 }
298 *self.buf = str_buf;
299 Ok(())
300 }
301 #[inline]
302 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
303 self,
304 idx: usize,
305 n_variants: usize,
306 null_variant: Option<usize>,
307 deserializer: D,
308 reader: &'b mut R,
309 ) -> Result<Self::Out, AvroError> {
310 if null_variant == Some(idx) {
311 for _ in 0..n_variants - 1 {
312 self.packer.push(Datum::Null)
313 }
314 } else {
315 let mut deserializer = Some(deserializer);
316 for i in 0..n_variants {
317 let dec = AvroFlatDecoder {
318 packer: self.packer,
319 buf: self.buf,
320 is_top: false,
321 };
322 if null_variant != Some(i) {
323 if i == idx {
324 deserializer.take().unwrap().deserialize(reader, dec)?;
325 } else {
326 self.packer.push(Datum::Null)
327 }
328 }
329 }
330 }
331 Ok(())
332 }
333
334 #[inline]
335 fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
336 self.packer.push(Datum::String(symbol));
337 Ok(())
338 }
339 #[inline]
340 fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
341 match scalar {
342 mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
343 mz_avro::types::Scalar::Boolean(val) => {
344 if val {
345 self.packer.push(Datum::True)
346 } else {
347 self.packer.push(Datum::False)
348 }
349 }
350 mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
351 mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
352 mz_avro::types::Scalar::Float(val) => {
353 self.packer.push(Datum::Float32(OrderedFloat(val)))
354 }
355 mz_avro::types::Scalar::Double(val) => {
356 self.packer.push(Datum::Float64(OrderedFloat(val)))
357 }
358 mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
359 Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
360 )),
361 mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
362 CheckedTimestamp::from_timestamplike(val)
363 .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
364 )),
365 }
366 Ok(())
367 }
368
369 #[inline]
370 fn decimal<'b, R: AvroRead>(
371 self,
372 _precision: usize,
373 scale: usize,
374 r: ValueOrReader<'b, &'b [u8], R>,
375 ) -> Result<Self::Out, AvroError> {
376 let mut buf = match r {
377 ValueOrReader::Value(val) => val.to_vec(),
378 ValueOrReader::Reader { len, r } => {
379 self.buf.resize_with(len, Default::default);
380 r.read_exact(self.buf)?;
381 let v = self.buf.clone();
382 v
383 }
384 };
385
386 let scale = u8::try_from(scale).map_err(|_| {
387 DecodeError::Custom(format!(
388 "Error decoding decimal: scale must fit within u8, but got scale {}",
389 scale,
390 ))
391 })?;
392
393 let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
394 .map_err(|e| e.to_string_with_causes())
395 .map_err(DecodeError::Custom)?;
396
397 if n.is_special()
398 || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
399 {
400 return Err(AvroError::Decode(DecodeError::Custom(format!(
401 "Error decoding numeric: exceeds maximum precision {}",
402 numeric::NUMERIC_DATUM_MAX_PRECISION
403 ))));
404 }
405
406 self.packer.push(Datum::from(n));
407
408 Ok(())
409 }
410
411 #[inline]
412 fn bytes<'b, R: AvroRead>(
413 self,
414 r: ValueOrReader<'b, &'b [u8], R>,
415 ) -> Result<Self::Out, AvroError> {
416 let buf = match r {
417 ValueOrReader::Value(val) => val,
418 ValueOrReader::Reader { len, r } => {
419 self.buf.resize_with(len, Default::default);
420 r.read_exact(self.buf)?;
421 self.buf
422 }
423 };
424 self.packer.push(Datum::Bytes(buf));
425 Ok(())
426 }
427 #[inline]
428 fn string<'b, R: AvroRead>(
429 self,
430 r: ValueOrReader<'b, &'b str, R>,
431 ) -> Result<Self::Out, AvroError> {
432 let s = match r {
433 ValueOrReader::Value(val) => val,
434 ValueOrReader::Reader { len, r } => {
435 self.buf.resize_with(len, Default::default);
440 r.read_exact(self.buf)?;
441 std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
442 }
443 };
444 self.packer.push(Datum::String(s));
445 Ok(())
446 }
447 #[inline]
448 fn json<'b, R: AvroRead>(
449 self,
450 r: ValueOrReader<'b, &'b serde_json::Value, R>,
451 ) -> Result<Self::Out, AvroError> {
452 match r {
453 ValueOrReader::Value(val) => {
454 JsonbPacker::new(self.packer)
455 .pack_serde_json(val.clone())
456 .map_err(|e| {
457 let bytes = val.to_string().into_bytes();
462
463 DecodeError::BadJson {
464 category: e.classify(),
465 bytes,
466 }
467 })?;
468 }
469 ValueOrReader::Reader { len, r } => {
470 self.buf.resize_with(len, Default::default);
471 r.read_exact(self.buf)?;
472 JsonbPacker::new(self.packer)
473 .pack_slice(self.buf)
474 .map_err(|e| DecodeError::BadJson {
475 category: e.classify(),
476 bytes: self.buf.to_owned(),
477 })?;
478 }
479 }
480 Ok(())
481 }
482 #[inline]
483 fn uuid<'b, R: AvroRead>(
484 self,
485 r: ValueOrReader<'b, &'b [u8], R>,
486 ) -> Result<Self::Out, AvroError> {
487 let buf = match r {
488 ValueOrReader::Value(val) => val,
489 ValueOrReader::Reader { len, r } => {
490 self.buf.resize_with(len, Default::default);
491 r.read_exact(self.buf)?;
492 self.buf
493 }
494 };
495 let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
496 self.packer.push(Datum::Uuid(
497 Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
498 ));
499 Ok(())
500 }
501 #[inline]
502 fn fixed<'b, R: AvroRead>(
503 self,
504 r: ValueOrReader<'b, &'b [u8], R>,
505 ) -> Result<Self::Out, AvroError> {
506 self.bytes(r)
507 }
508 #[inline]
509 fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
510 self.is_top = false;
511 let mut str_buf = std::mem::take(self.buf);
512 self.packer.push_list_with(|rp| -> Result<(), AvroError> {
513 loop {
514 let next = AvroFlatDecoder {
515 packer: rp,
516 buf: &mut str_buf,
517 is_top: false,
518 };
519 if a.decode_next(next)?.is_none() {
520 break;
521 }
522 }
523 Ok(())
524 })?;
525 *self.buf = str_buf;
526 Ok(())
527 }
528 #[inline]
529 fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
530 let mut map = BTreeMap::new();
532 while let Some((name, f)) = a.next_entry()? {
533 map.insert(name, f.decode_field(ValueDecoder)?);
534 }
535 self.packer
536 .push_dict_with(|packer| -> Result<(), AvroError> {
537 for (key, val) in map {
538 packer.push(Datum::String(key.as_str()));
539 give_value(
540 AvroFlatDecoder {
541 packer,
542 buf: &mut vec![],
543 is_top: false,
544 },
545 &val,
546 )?;
547 }
548 Ok(())
549 })?;
550
551 Ok(())
552 }
553}
554
/// A pair of optional values named `before` and `after`.
// NOTE(review): presumably models a change record (old state / new state) —
// confirm against the code that produces and consumes it.
#[derive(Clone, Debug)]
pub struct DiffPair<T> {
    /// The "before" value, if there is one.
    pub before: Option<T>,
    /// The "after" value, if there is one.
    pub after: Option<T>,
}