1use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15 AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16 GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use tracing::trace;
26use uuid::Uuid;
27
28use crate::avro::ConfluentAvroResolver;
29
/// Decodes Avro-encoded messages into [`Row`]s.
#[derive(Debug)]
pub struct Decoder {
    /// Resolver consulted for each message; it yields the resolved schema and an optional
    /// schema id that is reported in decode errors.
    csr_avro: ConfluentAvroResolver,
    /// Name included in trace output to identify this decoder.
    debug_name: String,
    /// Scratch buffer lent to [`AvroFlatDecoder`] while decoding a message.
    buf1: Vec<u8>,
    /// Row that decoded datums are packed into; a clone of it is returned on success.
    row_buf: Row,
}
38
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A failed decode must not leave behind state that corrupts the next, valid message.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, None, "Test".to_string(), false).unwrap();

        // A single byte cannot hold both int fields, so this message must be rejected.
        let mut truncated: &[u8] = &[0];
        assert_err!(decoder.decode(&mut truncated).await.unwrap());

        // Two zero bytes encode f1 = 0 and f2 = 0; this must decode cleanly afterwards.
        let mut complete: &[u8] = &[0, 0];
        let decoded = decoder.decode(&mut complete).await.unwrap().unwrap();
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded, expected);
    }
}
66
67impl Decoder {
68 pub fn new(
74 reader_schema: &str,
75 ccsr_client: Option<mz_ccsr::Client>,
76 debug_name: String,
77 confluent_wire_format: bool,
78 ) -> anyhow::Result<Decoder> {
79 let csr_avro =
80 ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
81
82 Ok(Decoder {
83 csr_avro,
84 debug_name,
85 buf1: vec![],
86 row_buf: Row::default(),
87 })
88 }
89
90 pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
92 let mut packer = self.row_buf.packer();
97 let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
99 Ok(ok) => ok,
100 Err(err) => return Ok(Err(err)),
101 };
102 *bytes = bytes2;
103 let dec = AvroFlatDecoder {
104 packer: &mut packer,
105 buf: &mut self.buf1,
106 is_top: true,
107 };
108 let dsr = GeneralDeserializer {
109 schema: resolved_schema.top_node(),
110 };
111 let result = dsr
112 .deserialize(bytes, dec)
113 .with_context(|| {
114 format!(
115 "unable to decode row {}",
116 match csr_schema_id {
117 Some(id) => format!("(Avro schema id = {:?})", id),
118 None => "".to_string(),
119 }
120 )
121 })
122 .map(|_| self.row_buf.clone());
123 if result.is_ok() {
124 trace!(
125 "[customer-data] Decoded row {:?} in {}",
126 self.row_buf, self.debug_name
127 );
128 }
129 Ok(result)
130 }
131}
132
/// An [`AvroDecode`] implementation that packs decoded Avro data into a [`RowPacker`].
///
/// Unions are flattened into one datum per non-null variant. Nested records and arrays are
/// packed as list datums, and maps are packed as dict datums.
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch buffer that bytes, strings, JSON, UUIDs and decimals are read into when they come
    /// from the underlying reader; it is resized and overwritten on each use.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder handles the top-level record. A top-level record's fields are packed
    /// directly into the row, while a nested record is packed as a list datum.
    pub is_top: bool,
}
139
140impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
141 type Out = ();
142 #[inline]
143 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
144 self,
145 a: &mut A,
146 ) -> Result<Self::Out, AvroError> {
147 let mut str_buf = std::mem::take(self.buf);
148 let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
149 let mut expected = 0;
150 let mut stash = vec![];
151 while let Some((_name, idx, f)) = a.next_field()? {
161 if idx == expected {
162 expected += 1;
163 f.decode_field(AvroFlatDecoder {
164 packer: rp,
165 buf: &mut str_buf,
166 is_top: false,
167 })?;
168 } else {
169 let val = f.decode_field(ValueDecoder)?;
170 stash.push((idx, val));
171 }
172 }
173 stash.sort_by_key(|(idx, _val)| *idx);
174 for (idx, val) in stash {
175 assert!(idx == expected);
176 expected += 1;
177 let dec = AvroFlatDecoder {
178 packer: rp,
179 buf: &mut str_buf,
180 is_top: false,
181 };
182 give_value(dec, &val)?;
183 }
184 Ok(())
185 };
186 if self.is_top {
187 pack_record(self.packer)?;
188 } else {
189 self.packer.push_list_with(pack_record)?;
190 }
191 *self.buf = str_buf;
192 Ok(())
193 }
194 #[inline]
195 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
196 self,
197 idx: usize,
198 n_variants: usize,
199 null_variant: Option<usize>,
200 deserializer: D,
201 reader: &'b mut R,
202 ) -> Result<Self::Out, AvroError> {
203 if null_variant == Some(idx) {
204 for _ in 0..n_variants - 1 {
205 self.packer.push(Datum::Null)
206 }
207 } else {
208 let mut deserializer = Some(deserializer);
209 for i in 0..n_variants {
210 let dec = AvroFlatDecoder {
211 packer: self.packer,
212 buf: self.buf,
213 is_top: false,
214 };
215 if null_variant != Some(i) {
216 if i == idx {
217 deserializer.take().unwrap().deserialize(reader, dec)?;
218 } else {
219 self.packer.push(Datum::Null)
220 }
221 }
222 }
223 }
224 Ok(())
225 }
226
227 #[inline]
228 fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
229 self.packer.push(Datum::String(symbol));
230 Ok(())
231 }
232 #[inline]
233 fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
234 match scalar {
235 mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
236 mz_avro::types::Scalar::Boolean(val) => {
237 if val {
238 self.packer.push(Datum::True)
239 } else {
240 self.packer.push(Datum::False)
241 }
242 }
243 mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
244 mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
245 mz_avro::types::Scalar::Float(val) => {
246 self.packer.push(Datum::Float32(OrderedFloat(val)))
247 }
248 mz_avro::types::Scalar::Double(val) => {
249 self.packer.push(Datum::Float64(OrderedFloat(val)))
250 }
251 mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
252 Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
253 )),
254 mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
255 CheckedTimestamp::from_timestamplike(val)
256 .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
257 )),
258 }
259 Ok(())
260 }
261
262 #[inline]
263 fn decimal<'b, R: AvroRead>(
264 self,
265 _precision: usize,
266 scale: usize,
267 r: ValueOrReader<'b, &'b [u8], R>,
268 ) -> Result<Self::Out, AvroError> {
269 let mut buf = match r {
270 ValueOrReader::Value(val) => val.to_vec(),
271 ValueOrReader::Reader { len, r } => {
272 self.buf.resize_with(len, Default::default);
273 r.read_exact(self.buf)?;
274 let v = self.buf.clone();
275 v
276 }
277 };
278
279 let scale = u8::try_from(scale).map_err(|_| {
280 DecodeError::Custom(format!(
281 "Error decoding decimal: scale must fit within u8, but got scale {}",
282 scale,
283 ))
284 })?;
285
286 let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
287 .map_err(|e| e.to_string_with_causes())
288 .map_err(DecodeError::Custom)?;
289
290 if n.is_special()
291 || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
292 {
293 return Err(AvroError::Decode(DecodeError::Custom(format!(
294 "Error decoding numeric: exceeds maximum precision {}",
295 numeric::NUMERIC_DATUM_MAX_PRECISION
296 ))));
297 }
298
299 self.packer.push(Datum::from(n));
300
301 Ok(())
302 }
303
304 #[inline]
305 fn bytes<'b, R: AvroRead>(
306 self,
307 r: ValueOrReader<'b, &'b [u8], R>,
308 ) -> Result<Self::Out, AvroError> {
309 let buf = match r {
310 ValueOrReader::Value(val) => val,
311 ValueOrReader::Reader { len, r } => {
312 self.buf.resize_with(len, Default::default);
313 r.read_exact(self.buf)?;
314 self.buf
315 }
316 };
317 self.packer.push(Datum::Bytes(buf));
318 Ok(())
319 }
320 #[inline]
321 fn string<'b, R: AvroRead>(
322 self,
323 r: ValueOrReader<'b, &'b str, R>,
324 ) -> Result<Self::Out, AvroError> {
325 let s = match r {
326 ValueOrReader::Value(val) => val,
327 ValueOrReader::Reader { len, r } => {
328 self.buf.resize_with(len, Default::default);
333 r.read_exact(self.buf)?;
334 std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
335 }
336 };
337 self.packer.push(Datum::String(s));
338 Ok(())
339 }
340 #[inline]
341 fn json<'b, R: AvroRead>(
342 self,
343 r: ValueOrReader<'b, &'b serde_json::Value, R>,
344 ) -> Result<Self::Out, AvroError> {
345 match r {
346 ValueOrReader::Value(val) => {
347 JsonbPacker::new(self.packer)
348 .pack_serde_json(val.clone())
349 .map_err(|e| {
350 let bytes = val.to_string().into_bytes();
355
356 DecodeError::BadJson {
357 category: e.classify(),
358 bytes,
359 }
360 })?;
361 }
362 ValueOrReader::Reader { len, r } => {
363 self.buf.resize_with(len, Default::default);
364 r.read_exact(self.buf)?;
365 JsonbPacker::new(self.packer)
366 .pack_slice(self.buf)
367 .map_err(|e| DecodeError::BadJson {
368 category: e.classify(),
369 bytes: self.buf.to_owned(),
370 })?;
371 }
372 }
373 Ok(())
374 }
375 #[inline]
376 fn uuid<'b, R: AvroRead>(
377 self,
378 r: ValueOrReader<'b, &'b [u8], R>,
379 ) -> Result<Self::Out, AvroError> {
380 let buf = match r {
381 ValueOrReader::Value(val) => val,
382 ValueOrReader::Reader { len, r } => {
383 self.buf.resize_with(len, Default::default);
384 r.read_exact(self.buf)?;
385 self.buf
386 }
387 };
388 let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
389 self.packer.push(Datum::Uuid(
390 Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
391 ));
392 Ok(())
393 }
394 #[inline]
395 fn fixed<'b, R: AvroRead>(
396 self,
397 r: ValueOrReader<'b, &'b [u8], R>,
398 ) -> Result<Self::Out, AvroError> {
399 self.bytes(r)
400 }
401 #[inline]
402 fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
403 self.is_top = false;
404 let mut str_buf = std::mem::take(self.buf);
405 self.packer.push_list_with(|rp| -> Result<(), AvroError> {
406 loop {
407 let next = AvroFlatDecoder {
408 packer: rp,
409 buf: &mut str_buf,
410 is_top: false,
411 };
412 if a.decode_next(next)?.is_none() {
413 break;
414 }
415 }
416 Ok(())
417 })?;
418 *self.buf = str_buf;
419 Ok(())
420 }
421 #[inline]
422 fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
423 let mut map = BTreeMap::new();
425 while let Some((name, f)) = a.next_entry()? {
426 map.insert(name, f.decode_field(ValueDecoder)?);
427 }
428 self.packer
429 .push_dict_with(|packer| -> Result<(), AvroError> {
430 for (key, val) in map {
431 packer.push(Datum::String(key.as_str()));
432 give_value(
433 AvroFlatDecoder {
434 packer,
435 buf: &mut vec![],
436 is_top: false,
437 },
438 &val,
439 )?;
440 }
441 Ok(())
442 })?;
443
444 Ok(())
445 }
446}
447
/// A pair of optional values labelled `before` and `after`.
///
/// NOTE(review): presumably the old and new versions of a changed record; confirm against the
/// code that constructs it.
#[derive(Clone, Debug)]
pub struct DiffPair<T> {
    /// The `before` value, if present.
    pub before: Option<T>,
    /// The `after` value, if present.
    pub after: Option<T>,
}