1use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15 AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16 GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::ConfluentAvroResolver;
30
/// Decodes Avro-encoded messages into [`Row`]s.
///
/// Holds the schema resolver plus scratch state (`buf1`, `row_buf`) that is reused across
/// calls to `decode`.
#[derive(Debug)]
pub struct Decoder {
    /// Resolves each incoming message to the schema it should be decoded with (and, when
    /// available, the registry schema id reported in decode errors).
    csr_avro: ConfluentAvroResolver,
    /// Name included in trace logging of decoded rows.
    debug_name: String,
    /// Scratch byte buffer lent to the [`AvroFlatDecoder`] for values read from the input.
    buf1: Vec<u8>,
    /// Reused row storage: each `decode` packs into it via `Row::packer` and returns a clone.
    row_buf: Row,
}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A failed decode must not poison the decoder: a truncated message produces an error,
    /// and the next well-formed message still decodes to the expected row.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, &[], None, "Test".to_string(), false).unwrap();

        // A single byte is too short for a record with two int fields.
        let mut truncated: &[u8] = &[0];
        let first_outcome = decoder.decode(&mut truncated).await.unwrap();
        assert_err!(first_outcome);

        // Two bytes: one per field, each decoding to the int 0.
        let mut complete: &[u8] = &[0, 0];
        let decoded_row = decoder.decode(&mut complete).await.unwrap().unwrap();
        let expected_row = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded_row, expected_row);
    }
}
67
68impl Decoder {
69 pub fn new(
79 reader_schema: &str,
80 reader_reference_schemas: &[String],
81 ccsr_client: Option<mz_ccsr::Client>,
82 debug_name: String,
83 confluent_wire_format: bool,
84 ) -> anyhow::Result<Decoder> {
85 let csr_avro = ConfluentAvroResolver::new(
86 reader_schema,
87 reader_reference_schemas,
88 ccsr_client,
89 confluent_wire_format,
90 )?;
91
92 Ok(Decoder {
93 csr_avro,
94 debug_name,
95 buf1: vec![],
96 row_buf: Row::default(),
97 })
98 }
99
100 pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
102 let mut packer = self.row_buf.packer();
107 let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
109 Ok(ok) => ok,
110 Err(err) => return Ok(Err(err)),
111 };
112 *bytes = bytes2;
113 let dec = AvroFlatDecoder {
114 packer: &mut packer,
115 buf: &mut self.buf1,
116 is_top: true,
117 };
118 let dsr = GeneralDeserializer {
119 schema: resolved_schema.top_node(),
120 };
121 let result = dsr
122 .deserialize(bytes, dec)
123 .with_context(|| {
124 format!(
125 "unable to decode row {}",
126 match csr_schema_id {
127 Some(id) => format!("(Avro schema id = {:?})", id),
128 None => "".to_string(),
129 }
130 )
131 })
132 .map(|_| self.row_buf.clone());
133 if result.is_ok() {
134 trace!(
135 "[customer-data] Decoded row {:?} in {}",
136 self.row_buf, self.debug_name
137 );
138 }
139 Ok(result)
140 }
141}
142
/// An [`AvroDecode`] visitor that packs decoded Avro values as datums into a [`RowPacker`].
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch buffer for values read directly from the input (bytes, strings, JSON, ...).
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder is at the top level. A top-level record's fields are packed
    /// directly into the row; nested records are packed as a list instead.
    pub is_top: bool,
}
149
150impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
151 type Out = ();
152 #[inline]
153 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
154 self,
155 a: &mut A,
156 ) -> Result<Self::Out, AvroError> {
157 let mut str_buf = std::mem::take(self.buf);
158 let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
159 let mut expected = 0;
160 let mut stash = vec![];
161 while let Some((_name, idx, f)) = a.next_field()? {
171 if idx == expected {
172 expected += 1;
173 f.decode_field(AvroFlatDecoder {
174 packer: rp,
175 buf: &mut str_buf,
176 is_top: false,
177 })?;
178 } else {
179 let val = f.decode_field(ValueDecoder)?;
180 stash.push((idx, val));
181 }
182 }
183 stash.sort_by_key(|(idx, _val)| *idx);
184 for (idx, val) in stash {
185 assert!(idx == expected);
186 expected += 1;
187 let dec = AvroFlatDecoder {
188 packer: rp,
189 buf: &mut str_buf,
190 is_top: false,
191 };
192 give_value(dec, &val)?;
193 }
194 Ok(())
195 };
196 if self.is_top {
197 pack_record(self.packer)?;
198 } else {
199 self.packer.push_list_with(pack_record)?;
200 }
201 *self.buf = str_buf;
202 Ok(())
203 }
204 #[inline]
205 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
206 self,
207 idx: usize,
208 n_variants: usize,
209 null_variant: Option<usize>,
210 deserializer: D,
211 reader: &'b mut R,
212 ) -> Result<Self::Out, AvroError> {
213 if null_variant == Some(idx) {
214 for _ in 0..n_variants - 1 {
215 self.packer.push(Datum::Null)
216 }
217 } else {
218 let mut deserializer = Some(deserializer);
219 for i in 0..n_variants {
220 let dec = AvroFlatDecoder {
221 packer: self.packer,
222 buf: self.buf,
223 is_top: false,
224 };
225 if null_variant != Some(i) {
226 if i == idx {
227 deserializer.take().unwrap().deserialize(reader, dec)?;
228 } else {
229 self.packer.push(Datum::Null)
230 }
231 }
232 }
233 }
234 Ok(())
235 }
236
237 #[inline]
238 fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
239 self.packer.push(Datum::String(symbol));
240 Ok(())
241 }
242 #[inline]
243 fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
244 match scalar {
245 mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
246 mz_avro::types::Scalar::Boolean(val) => {
247 if val {
248 self.packer.push(Datum::True)
249 } else {
250 self.packer.push(Datum::False)
251 }
252 }
253 mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
254 mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
255 mz_avro::types::Scalar::Float(val) => {
256 self.packer.push(Datum::Float32(OrderedFloat(val)))
257 }
258 mz_avro::types::Scalar::Double(val) => {
259 self.packer.push(Datum::Float64(OrderedFloat(val)))
260 }
261 mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
262 Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
263 )),
264 mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
265 CheckedTimestamp::from_timestamplike(val)
266 .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
267 )),
268 }
269 Ok(())
270 }
271
272 #[inline]
273 fn decimal<'b, R: AvroRead>(
274 self,
275 _precision: usize,
276 scale: usize,
277 r: ValueOrReader<'b, &'b [u8], R>,
278 ) -> Result<Self::Out, AvroError> {
279 let mut buf = match r {
280 ValueOrReader::Value(val) => val.to_vec(),
281 ValueOrReader::Reader { len, r } => {
282 self.buf.resize_with(len, Default::default);
283 r.read_exact(self.buf)?;
284 let v = self.buf.clone();
285 v
286 }
287 };
288
289 let scale = u8::try_from(scale).map_err(|_| {
290 DecodeError::Custom(format!(
291 "Error decoding decimal: scale must fit within u8, but got scale {}",
292 scale,
293 ))
294 })?;
295
296 let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
297 .map_err(|e| e.to_string_with_causes())
298 .map_err(DecodeError::Custom)?;
299
300 if n.is_special()
301 || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
302 {
303 return Err(AvroError::Decode(DecodeError::Custom(format!(
304 "Error decoding numeric: exceeds maximum precision {}",
305 numeric::NUMERIC_DATUM_MAX_PRECISION
306 ))));
307 }
308
309 self.packer.push(Datum::from(n));
310
311 Ok(())
312 }
313
314 #[inline]
315 fn bytes<'b, R: AvroRead>(
316 self,
317 r: ValueOrReader<'b, &'b [u8], R>,
318 ) -> Result<Self::Out, AvroError> {
319 let buf = match r {
320 ValueOrReader::Value(val) => val,
321 ValueOrReader::Reader { len, r } => {
322 self.buf.resize_with(len, Default::default);
323 r.read_exact(self.buf)?;
324 self.buf
325 }
326 };
327 self.packer.push(Datum::Bytes(buf));
328 Ok(())
329 }
330 #[inline]
331 fn string<'b, R: AvroRead>(
332 self,
333 r: ValueOrReader<'b, &'b str, R>,
334 ) -> Result<Self::Out, AvroError> {
335 let s = match r {
336 ValueOrReader::Value(val) => val,
337 ValueOrReader::Reader { len, r } => {
338 self.buf.resize_with(len, Default::default);
343 r.read_exact(self.buf)?;
344 std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
345 }
346 };
347 self.packer.push(Datum::String(s));
348 Ok(())
349 }
350 #[inline]
351 fn json<'b, R: AvroRead>(
352 self,
353 r: ValueOrReader<'b, &'b serde_json::Value, R>,
354 ) -> Result<Self::Out, AvroError> {
355 match r {
356 ValueOrReader::Value(val) => {
357 JsonbPacker::new(self.packer)
358 .pack_serde_json(val.clone())
359 .map_err(|e| {
360 let bytes = val.to_string().into_bytes();
365
366 DecodeError::BadJson {
367 category: e.classify(),
368 bytes,
369 }
370 })?;
371 }
372 ValueOrReader::Reader { len, r } => {
373 self.buf.resize_with(len, Default::default);
374 r.read_exact(self.buf)?;
375 JsonbPacker::new(self.packer)
376 .pack_slice(self.buf)
377 .map_err(|e| DecodeError::BadJson {
378 category: e.classify(),
379 bytes: self.buf.to_owned(),
380 })?;
381 }
382 }
383 Ok(())
384 }
385 #[inline]
386 fn uuid<'b, R: AvroRead>(
387 self,
388 r: ValueOrReader<'b, &'b [u8], R>,
389 ) -> Result<Self::Out, AvroError> {
390 let buf = match r {
391 ValueOrReader::Value(val) => val,
392 ValueOrReader::Reader { len, r } => {
393 self.buf.resize_with(len, Default::default);
394 r.read_exact(self.buf)?;
395 self.buf
396 }
397 };
398 let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
399 self.packer.push(Datum::Uuid(
400 Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
401 ));
402 Ok(())
403 }
404 #[inline]
405 fn fixed<'b, R: AvroRead>(
406 self,
407 r: ValueOrReader<'b, &'b [u8], R>,
408 ) -> Result<Self::Out, AvroError> {
409 self.bytes(r)
410 }
411 #[inline]
412 fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
413 self.is_top = false;
414 let mut str_buf = std::mem::take(self.buf);
415 self.packer.push_list_with(|rp| -> Result<(), AvroError> {
416 loop {
417 let next = AvroFlatDecoder {
418 packer: rp,
419 buf: &mut str_buf,
420 is_top: false,
421 };
422 if a.decode_next(next)?.is_none() {
423 break;
424 }
425 }
426 Ok(())
427 })?;
428 *self.buf = str_buf;
429 Ok(())
430 }
431 #[inline]
432 fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
433 let mut map = BTreeMap::new();
435 while let Some((name, f)) = a.next_entry()? {
436 map.insert(name, f.decode_field(ValueDecoder)?);
437 }
438 self.packer
439 .push_dict_with(|packer| -> Result<(), AvroError> {
440 for (key, val) in map {
441 packer.push(Datum::String(key.as_str()));
442 give_value(
443 AvroFlatDecoder {
444 packer,
445 buf: &mut vec![],
446 is_top: false,
447 },
448 &val,
449 )?;
450 }
451 Ok(())
452 })?;
453
454 Ok(())
455 }
456}
457
/// A pair of optional `before` and `after` values.
///
/// NOTE(review): presumably models the before/after images of a change-data-capture update;
/// confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiffPair<T> {
    /// The `before` value, if present.
    pub before: Option<T>,
    /// The `after` value, if present.
    pub after: Option<T>,
}