1use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15 AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16 GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::AvroSchemaResolver;
30
/// Decodes Avro-encoded byte slices into [`Row`]s.
///
/// Construct one with [`Decoder::new`] and call [`Decoder::decode`] once per
/// message; the internal buffers are reused between calls.
#[derive(Debug)]
pub struct Decoder {
    /// Resolves the schema used to decode each incoming message.
    csr_avro: AvroSchemaResolver,
    /// Name included in trace output to identify this decoder.
    debug_name: String,
    /// Scratch buffer lent to the value decoder on every `decode` call.
    buf1: Vec<u8>,
    /// Row storage that each `decode` call packs into and then clones from.
    row_buf: Row,
}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::{Decoder, WriterSchemaProvider};

    /// A failed decode must not leave state behind that corrupts the next,
    /// successful decode on the same `Decoder`.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder =
            Decoder::new(schema, &[], WriterSchemaProvider::None, "Test".to_string()).unwrap();

        // A single byte is not enough input for a record with two int fields,
        // so this decode is expected to report an error.
        let mut truncated: &[u8] = &[0];
        assert_err!(decoder.decode(&mut truncated).await.unwrap());

        // Two bytes supply both fields; the row must contain exactly those two
        // values and nothing left over from the failed attempt above.
        let mut complete: &[u8] = &[0, 0];
        let decoded = decoder.decode(&mut complete).await.unwrap().unwrap();
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded, expected);
    }
}
68
69impl Decoder {
70 pub fn new(
80 reader_schema: &str,
81 reader_reference_schemas: &[String],
82 writer_schemas: crate::avro::WriterSchemaProvider,
83 debug_name: String,
84 ) -> anyhow::Result<Decoder> {
85 let csr_avro =
86 AvroSchemaResolver::new(reader_schema, reader_reference_schemas, writer_schemas)?;
87
88 Ok(Decoder {
89 csr_avro,
90 debug_name,
91 buf1: vec![],
92 row_buf: Row::default(),
93 })
94 }
95
96 pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
98 let mut packer = self.row_buf.packer();
103 let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
105 Ok(ok) => ok,
106 Err(err) => return Ok(Err(err)),
107 };
108 *bytes = bytes2;
109 let dec = AvroFlatDecoder {
110 packer: &mut packer,
111 buf: &mut self.buf1,
112 is_top: true,
113 };
114 let dsr = GeneralDeserializer {
115 schema: resolved_schema.top_node(),
116 };
117 let result = dsr
118 .deserialize(bytes, dec)
119 .with_context(|| {
120 format!(
121 "unable to decode row {}",
122 match &csr_schema_id {
123 Some(id) => format!("(Avro schema id = {:?})", id),
124 None => "".to_string(),
125 }
126 )
127 })
128 .map(|_| self.row_buf.clone());
129 if result.is_ok() {
130 trace!(
131 "[customer-data] Decoded row {:?} in {}",
132 self.row_buf, self.debug_name
133 );
134 }
135 Ok(result)
136 }
137}
138
/// An [`AvroDecode`] implementation that packs decoded values into a
/// [`RowPacker`].
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination that decoded datums are pushed onto.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch space for values that must be read from the input before they
    /// can be packed.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder handles the outermost value. A top-level record
    /// pushes its fields directly onto `packer`; a nested record wraps them
    /// in a list.
    pub is_top: bool,
}
145
146impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
147 type Out = ();
148 #[inline]
149 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
150 self,
151 a: &mut A,
152 ) -> Result<Self::Out, AvroError> {
153 let mut str_buf = std::mem::take(self.buf);
154 let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
155 let mut expected = 0;
156 let mut stash = vec![];
157 while let Some((_name, idx, f)) = a.next_field()? {
167 if idx == expected {
168 expected += 1;
169 f.decode_field(AvroFlatDecoder {
170 packer: rp,
171 buf: &mut str_buf,
172 is_top: false,
173 })?;
174 } else {
175 let val = f.decode_field(ValueDecoder)?;
176 stash.push((idx, val));
177 }
178 }
179 stash.sort_by_key(|(idx, _val)| *idx);
180 for (idx, val) in stash {
181 assert!(idx == expected);
182 expected += 1;
183 let dec = AvroFlatDecoder {
184 packer: rp,
185 buf: &mut str_buf,
186 is_top: false,
187 };
188 give_value(dec, &val)?;
189 }
190 Ok(())
191 };
192 if self.is_top {
193 pack_record(self.packer)?;
194 } else {
195 self.packer.push_list_with(pack_record)?;
196 }
197 *self.buf = str_buf;
198 Ok(())
199 }
200 #[inline]
201 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
202 self,
203 idx: usize,
204 n_variants: usize,
205 null_variant: Option<usize>,
206 deserializer: D,
207 reader: &'b mut R,
208 ) -> Result<Self::Out, AvroError> {
209 if null_variant == Some(idx) {
210 for _ in 0..n_variants - 1 {
211 self.packer.push(Datum::Null)
212 }
213 } else {
214 let mut deserializer = Some(deserializer);
215 for i in 0..n_variants {
216 let dec = AvroFlatDecoder {
217 packer: self.packer,
218 buf: self.buf,
219 is_top: false,
220 };
221 if null_variant != Some(i) {
222 if i == idx {
223 deserializer.take().unwrap().deserialize(reader, dec)?;
224 } else {
225 self.packer.push(Datum::Null)
226 }
227 }
228 }
229 }
230 Ok(())
231 }
232
233 #[inline]
234 fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
235 self.packer.push(Datum::String(symbol));
236 Ok(())
237 }
238 #[inline]
239 fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
240 match scalar {
241 mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
242 mz_avro::types::Scalar::Boolean(val) => {
243 if val {
244 self.packer.push(Datum::True)
245 } else {
246 self.packer.push(Datum::False)
247 }
248 }
249 mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
250 mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
251 mz_avro::types::Scalar::Float(val) => {
252 self.packer.push(Datum::Float32(OrderedFloat(val)))
253 }
254 mz_avro::types::Scalar::Double(val) => {
255 self.packer.push(Datum::Float64(OrderedFloat(val)))
256 }
257 mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
258 Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
259 )),
260 mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
261 CheckedTimestamp::from_timestamplike(val)
262 .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
263 )),
264 }
265 Ok(())
266 }
267
268 #[inline]
269 fn decimal<'b, R: AvroRead>(
270 self,
271 _precision: usize,
272 scale: usize,
273 r: ValueOrReader<'b, &'b [u8], R>,
274 ) -> Result<Self::Out, AvroError> {
275 let mut buf = match r {
276 ValueOrReader::Value(val) => val.to_vec(),
277 ValueOrReader::Reader { len, r } => {
278 self.buf.resize_with(len, Default::default);
279 r.read_exact(self.buf)?;
280 let v = self.buf.clone();
281 v
282 }
283 };
284
285 let scale = u8::try_from(scale).map_err(|_| {
286 DecodeError::Custom(format!(
287 "Error decoding decimal: scale must fit within u8, but got scale {}",
288 scale,
289 ))
290 })?;
291
292 let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
293 .map_err(|e| e.to_string_with_causes())
294 .map_err(DecodeError::Custom)?;
295
296 if n.is_special()
297 || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
298 {
299 return Err(AvroError::Decode(DecodeError::Custom(format!(
300 "Error decoding numeric: exceeds maximum precision {}",
301 numeric::NUMERIC_DATUM_MAX_PRECISION
302 ))));
303 }
304
305 self.packer.push(Datum::from(n));
306
307 Ok(())
308 }
309
310 #[inline]
311 fn bytes<'b, R: AvroRead>(
312 self,
313 r: ValueOrReader<'b, &'b [u8], R>,
314 ) -> Result<Self::Out, AvroError> {
315 let buf = match r {
316 ValueOrReader::Value(val) => val,
317 ValueOrReader::Reader { len, r } => {
318 self.buf.resize_with(len, Default::default);
319 r.read_exact(self.buf)?;
320 self.buf
321 }
322 };
323 self.packer.push(Datum::Bytes(buf));
324 Ok(())
325 }
326 #[inline]
327 fn string<'b, R: AvroRead>(
328 self,
329 r: ValueOrReader<'b, &'b str, R>,
330 ) -> Result<Self::Out, AvroError> {
331 let s = match r {
332 ValueOrReader::Value(val) => val,
333 ValueOrReader::Reader { len, r } => {
334 self.buf.resize_with(len, Default::default);
339 r.read_exact(self.buf)?;
340 std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
341 }
342 };
343 self.packer.push(Datum::String(s));
344 Ok(())
345 }
346 #[inline]
347 fn json<'b, R: AvroRead>(
348 self,
349 r: ValueOrReader<'b, &'b serde_json::Value, R>,
350 ) -> Result<Self::Out, AvroError> {
351 match r {
352 ValueOrReader::Value(val) => {
353 JsonbPacker::new(self.packer)
354 .pack_serde_json(val.clone())
355 .map_err(|e| {
356 let bytes = val.to_string().into_bytes();
361
362 DecodeError::BadJson {
363 category: e.classify(),
364 bytes,
365 }
366 })?;
367 }
368 ValueOrReader::Reader { len, r } => {
369 self.buf.resize_with(len, Default::default);
370 r.read_exact(self.buf)?;
371 JsonbPacker::new(self.packer)
372 .pack_slice(self.buf)
373 .map_err(|e| DecodeError::BadJson {
374 category: e.classify(),
375 bytes: self.buf.to_owned(),
376 })?;
377 }
378 }
379 Ok(())
380 }
381 #[inline]
382 fn uuid<'b, R: AvroRead>(
383 self,
384 r: ValueOrReader<'b, &'b [u8], R>,
385 ) -> Result<Self::Out, AvroError> {
386 let buf = match r {
387 ValueOrReader::Value(val) => val,
388 ValueOrReader::Reader { len, r } => {
389 self.buf.resize_with(len, Default::default);
390 r.read_exact(self.buf)?;
391 self.buf
392 }
393 };
394 let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
395 self.packer.push(Datum::Uuid(
396 Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
397 ));
398 Ok(())
399 }
400 #[inline]
401 fn fixed<'b, R: AvroRead>(
402 self,
403 r: ValueOrReader<'b, &'b [u8], R>,
404 ) -> Result<Self::Out, AvroError> {
405 self.bytes(r)
406 }
407 #[inline]
408 fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
409 self.is_top = false;
410 let mut str_buf = std::mem::take(self.buf);
411 self.packer.push_list_with(|rp| -> Result<(), AvroError> {
412 loop {
413 let next = AvroFlatDecoder {
414 packer: rp,
415 buf: &mut str_buf,
416 is_top: false,
417 };
418 if a.decode_next(next)?.is_none() {
419 break;
420 }
421 }
422 Ok(())
423 })?;
424 *self.buf = str_buf;
425 Ok(())
426 }
427 #[inline]
428 fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
429 let mut map = BTreeMap::new();
431 while let Some((name, f)) = a.next_entry()? {
432 map.insert(name, f.decode_field(ValueDecoder)?);
433 }
434 self.packer
435 .push_dict_with(|packer| -> Result<(), AvroError> {
436 for (key, val) in map {
437 packer.push(Datum::String(key.as_str()));
438 give_value(
439 AvroFlatDecoder {
440 packer,
441 buf: &mut vec![],
442 is_top: false,
443 },
444 &val,
445 )?;
446 }
447 Ok(())
448 })?;
449
450 Ok(())
451 }
452}
453
/// A pair of optional values of the same type, labelled `before` and `after`.
// NOTE(review): presumably the old and new state of a changed record; confirm
// against the callers that construct this type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiffPair<T> {
    /// The `before` value, if there is one.
    pub before: Option<T>,
    /// The `after` value, if there is one.
    pub after: Option<T>,
}