1use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15 AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16 GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::ConfluentAvroResolver;
30
31#[derive(Debug)]
33pub struct Decoder {
34 csr_avro: ConfluentAvroResolver,
35 debug_name: String,
36 buf1: Vec<u8>,
37 row_buf: Row,
38}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// Checks that a decode error on one input does not prevent a later,
    /// well-formed input from decoding correctly on the same decoder.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let two_int_record = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder =
            Decoder::new(two_int_record, None, "Test".to_string(), false).unwrap();

        // A single byte: not enough data for a record declaring two fields.
        let mut truncated: &[u8] = &[0];
        let first_attempt = decoder.decode(&mut truncated).await.unwrap();
        assert_err!(first_attempt);

        // One byte per field; each is expected to decode to the integer 0.
        let mut complete: &[u8] = &[0, 0];
        let second_attempt = decoder.decode(&mut complete).await.unwrap().unwrap();
        assert_eq!(
            second_attempt,
            Row::pack([Datum::Int32(0), Datum::Int32(0)])
        );
    }
}
67
68impl Decoder {
69 pub fn new(
75 reader_schema: &str,
76 ccsr_client: Option<mz_ccsr::Client>,
77 debug_name: String,
78 confluent_wire_format: bool,
79 ) -> anyhow::Result<Decoder> {
80 let csr_avro =
81 ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
82
83 Ok(Decoder {
84 csr_avro,
85 debug_name,
86 buf1: vec![],
87 row_buf: Row::default(),
88 })
89 }
90
91 pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
93 let mut packer = self.row_buf.packer();
98 let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
100 Ok(ok) => ok,
101 Err(err) => return Ok(Err(err)),
102 };
103 *bytes = bytes2;
104 let dec = AvroFlatDecoder {
105 packer: &mut packer,
106 buf: &mut self.buf1,
107 is_top: true,
108 };
109 let dsr = GeneralDeserializer {
110 schema: resolved_schema.top_node(),
111 };
112 let result = dsr
113 .deserialize(bytes, dec)
114 .with_context(|| {
115 format!(
116 "unable to decode row {}",
117 match csr_schema_id {
118 Some(id) => format!("(Avro schema id = {:?})", id),
119 None => "".to_string(),
120 }
121 )
122 })
123 .map(|_| self.row_buf.clone());
124 if result.is_ok() {
125 trace!(
126 "[customer-data] Decoded row {:?} in {}",
127 self.row_buf, self.debug_name
128 );
129 }
130 Ok(result)
131 }
132}
133
134#[derive(Debug)]
135pub struct AvroFlatDecoder<'a, 'row> {
136 pub packer: &'a mut RowPacker<'row>,
137 pub buf: &'a mut Vec<u8>,
138 pub is_top: bool,
139}
140
141impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
142 type Out = ();
143 #[inline]
144 fn record<R: AvroRead, A: AvroRecordAccess<R>>(
145 self,
146 a: &mut A,
147 ) -> Result<Self::Out, AvroError> {
148 let mut str_buf = std::mem::take(self.buf);
149 let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
150 let mut expected = 0;
151 let mut stash = vec![];
152 while let Some((_name, idx, f)) = a.next_field()? {
162 if idx == expected {
163 expected += 1;
164 f.decode_field(AvroFlatDecoder {
165 packer: rp,
166 buf: &mut str_buf,
167 is_top: false,
168 })?;
169 } else {
170 let val = f.decode_field(ValueDecoder)?;
171 stash.push((idx, val));
172 }
173 }
174 stash.sort_by_key(|(idx, _val)| *idx);
175 for (idx, val) in stash {
176 assert!(idx == expected);
177 expected += 1;
178 let dec = AvroFlatDecoder {
179 packer: rp,
180 buf: &mut str_buf,
181 is_top: false,
182 };
183 give_value(dec, &val)?;
184 }
185 Ok(())
186 };
187 if self.is_top {
188 pack_record(self.packer)?;
189 } else {
190 self.packer.push_list_with(pack_record)?;
191 }
192 *self.buf = str_buf;
193 Ok(())
194 }
195 #[inline]
196 fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
197 self,
198 idx: usize,
199 n_variants: usize,
200 null_variant: Option<usize>,
201 deserializer: D,
202 reader: &'b mut R,
203 ) -> Result<Self::Out, AvroError> {
204 if null_variant == Some(idx) {
205 for _ in 0..n_variants - 1 {
206 self.packer.push(Datum::Null)
207 }
208 } else {
209 let mut deserializer = Some(deserializer);
210 for i in 0..n_variants {
211 let dec = AvroFlatDecoder {
212 packer: self.packer,
213 buf: self.buf,
214 is_top: false,
215 };
216 if null_variant != Some(i) {
217 if i == idx {
218 deserializer.take().unwrap().deserialize(reader, dec)?;
219 } else {
220 self.packer.push(Datum::Null)
221 }
222 }
223 }
224 }
225 Ok(())
226 }
227
228 #[inline]
229 fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
230 self.packer.push(Datum::String(symbol));
231 Ok(())
232 }
233 #[inline]
234 fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
235 match scalar {
236 mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
237 mz_avro::types::Scalar::Boolean(val) => {
238 if val {
239 self.packer.push(Datum::True)
240 } else {
241 self.packer.push(Datum::False)
242 }
243 }
244 mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
245 mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
246 mz_avro::types::Scalar::Float(val) => {
247 self.packer.push(Datum::Float32(OrderedFloat(val)))
248 }
249 mz_avro::types::Scalar::Double(val) => {
250 self.packer.push(Datum::Float64(OrderedFloat(val)))
251 }
252 mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
253 Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
254 )),
255 mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
256 CheckedTimestamp::from_timestamplike(val)
257 .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
258 )),
259 }
260 Ok(())
261 }
262
263 #[inline]
264 fn decimal<'b, R: AvroRead>(
265 self,
266 _precision: usize,
267 scale: usize,
268 r: ValueOrReader<'b, &'b [u8], R>,
269 ) -> Result<Self::Out, AvroError> {
270 let mut buf = match r {
271 ValueOrReader::Value(val) => val.to_vec(),
272 ValueOrReader::Reader { len, r } => {
273 self.buf.resize_with(len, Default::default);
274 r.read_exact(self.buf)?;
275 let v = self.buf.clone();
276 v
277 }
278 };
279
280 let scale = u8::try_from(scale).map_err(|_| {
281 DecodeError::Custom(format!(
282 "Error decoding decimal: scale must fit within u8, but got scale {}",
283 scale,
284 ))
285 })?;
286
287 let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
288 .map_err(|e| e.to_string_with_causes())
289 .map_err(DecodeError::Custom)?;
290
291 if n.is_special()
292 || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
293 {
294 return Err(AvroError::Decode(DecodeError::Custom(format!(
295 "Error decoding numeric: exceeds maximum precision {}",
296 numeric::NUMERIC_DATUM_MAX_PRECISION
297 ))));
298 }
299
300 self.packer.push(Datum::from(n));
301
302 Ok(())
303 }
304
305 #[inline]
306 fn bytes<'b, R: AvroRead>(
307 self,
308 r: ValueOrReader<'b, &'b [u8], R>,
309 ) -> Result<Self::Out, AvroError> {
310 let buf = match r {
311 ValueOrReader::Value(val) => val,
312 ValueOrReader::Reader { len, r } => {
313 self.buf.resize_with(len, Default::default);
314 r.read_exact(self.buf)?;
315 self.buf
316 }
317 };
318 self.packer.push(Datum::Bytes(buf));
319 Ok(())
320 }
321 #[inline]
322 fn string<'b, R: AvroRead>(
323 self,
324 r: ValueOrReader<'b, &'b str, R>,
325 ) -> Result<Self::Out, AvroError> {
326 let s = match r {
327 ValueOrReader::Value(val) => val,
328 ValueOrReader::Reader { len, r } => {
329 self.buf.resize_with(len, Default::default);
334 r.read_exact(self.buf)?;
335 std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
336 }
337 };
338 self.packer.push(Datum::String(s));
339 Ok(())
340 }
341 #[inline]
342 fn json<'b, R: AvroRead>(
343 self,
344 r: ValueOrReader<'b, &'b serde_json::Value, R>,
345 ) -> Result<Self::Out, AvroError> {
346 match r {
347 ValueOrReader::Value(val) => {
348 JsonbPacker::new(self.packer)
349 .pack_serde_json(val.clone())
350 .map_err(|e| {
351 let bytes = val.to_string().into_bytes();
356
357 DecodeError::BadJson {
358 category: e.classify(),
359 bytes,
360 }
361 })?;
362 }
363 ValueOrReader::Reader { len, r } => {
364 self.buf.resize_with(len, Default::default);
365 r.read_exact(self.buf)?;
366 JsonbPacker::new(self.packer)
367 .pack_slice(self.buf)
368 .map_err(|e| DecodeError::BadJson {
369 category: e.classify(),
370 bytes: self.buf.to_owned(),
371 })?;
372 }
373 }
374 Ok(())
375 }
376 #[inline]
377 fn uuid<'b, R: AvroRead>(
378 self,
379 r: ValueOrReader<'b, &'b [u8], R>,
380 ) -> Result<Self::Out, AvroError> {
381 let buf = match r {
382 ValueOrReader::Value(val) => val,
383 ValueOrReader::Reader { len, r } => {
384 self.buf.resize_with(len, Default::default);
385 r.read_exact(self.buf)?;
386 self.buf
387 }
388 };
389 let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
390 self.packer.push(Datum::Uuid(
391 Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
392 ));
393 Ok(())
394 }
395 #[inline]
396 fn fixed<'b, R: AvroRead>(
397 self,
398 r: ValueOrReader<'b, &'b [u8], R>,
399 ) -> Result<Self::Out, AvroError> {
400 self.bytes(r)
401 }
402 #[inline]
403 fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
404 self.is_top = false;
405 let mut str_buf = std::mem::take(self.buf);
406 self.packer.push_list_with(|rp| -> Result<(), AvroError> {
407 loop {
408 let next = AvroFlatDecoder {
409 packer: rp,
410 buf: &mut str_buf,
411 is_top: false,
412 };
413 if a.decode_next(next)?.is_none() {
414 break;
415 }
416 }
417 Ok(())
418 })?;
419 *self.buf = str_buf;
420 Ok(())
421 }
422 #[inline]
423 fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
424 let mut map = BTreeMap::new();
426 while let Some((name, f)) = a.next_entry()? {
427 map.insert(name, f.decode_field(ValueDecoder)?);
428 }
429 self.packer
430 .push_dict_with(|packer| -> Result<(), AvroError> {
431 for (key, val) in map {
432 packer.push(Datum::String(key.as_str()));
433 give_value(
434 AvroFlatDecoder {
435 packer,
436 buf: &mut vec![],
437 is_top: false,
438 },
439 &val,
440 )?;
441 }
442 Ok(())
443 })?;
444
445 Ok(())
446 }
447}
448
449#[derive(Clone, Debug, Serialize, Deserialize)]
450pub struct DiffPair<T> {
451 pub before: Option<T>,
452 pub after: Option<T>,
453}