mz_testdrive/format/
avro.rs

1// Copyright 2018 Flavien Raynaud
2// Copyright Materialize, Inc. and contributors. All rights reserved.
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE file.
6//
7// As of the Change Date specified in that file, in accordance with
8// the Business Source License, use of this software will be governed
9// by the Apache License, Version 2.0.
10//
11// Portions of this file are derived from the ToAvro implementation for
12// serde_json::Value that is shipped with the avro_rs project. The original
13// source code was retrieved on April 25, 2019 from:
14//
15//     https://github.com/flavray/avro-rs/blob/c4971ac08f52750db6bc95559c2b5faa6c0c9a06/src/types.rs
16//
17// The original source code is subject to the terms of the MIT license, a copy
18// of which can be found in the LICENSE file at the root of this repository.
19
20use std::collections::BTreeMap;
21use std::convert::{TryFrom, TryInto};
22use std::fmt::Debug;
23
24use anyhow::{Context, anyhow, bail};
25use byteorder::{BigEndian, ByteOrder};
26use chrono::NaiveDate;
27// Re-export components from the various other Avro libraries, so that other
28// testdrive modules can import just this one.
29pub use mz_avro::schema::{Schema, SchemaKind, SchemaNode, SchemaPiece, SchemaPieceOrNamed};
30pub use mz_avro::types::{DecimalValue, ToAvro, Value};
31pub use mz_avro::{from_avro_datum, to_avro_datum};
32pub use mz_interchange::avro::parse_schema;
33use serde_json::Value as JsonValue;
34
35// This function is derived from code in the avro_rs project. Update the license
36// header on this file accordingly if you move it to a new home.
37pub fn from_json(json: &JsonValue, schema: SchemaNode) -> Result<Value, anyhow::Error> {
38    match (&json, &schema.inner) {
39        (JsonValue::Null, SchemaPiece::Null) => Ok(Value::Null),
40        (JsonValue::Bool(b), SchemaPiece::Boolean) => Ok(Value::Boolean(*b)),
41        (JsonValue::Number(n), SchemaPiece::Int) => Ok(Value::Int(n.as_i64().unwrap().try_into()?)),
42        (JsonValue::Number(n), SchemaPiece::Long) => Ok(Value::Long(n.as_i64().unwrap())),
43        (JsonValue::Number(n), SchemaPiece::Float) => {
44            // No other known way to cast an `f64` to an `f32`.
45            #[allow(clippy::as_conversions)]
46            Ok(Value::Float(n.as_f64().unwrap() as f32))
47        }
48        (JsonValue::Number(n), SchemaPiece::Double) => Ok(Value::Double(n.as_f64().unwrap())),
49        (JsonValue::Number(n), SchemaPiece::Date) => {
50            Ok(Value::Date(i32::try_from(n.as_i64().unwrap())?))
51        }
52        (JsonValue::Number(n), SchemaPiece::TimestampMilli) => {
53            let ts = n.as_i64().unwrap();
54            Ok(Value::Timestamp(
55                chrono::DateTime::from_timestamp_millis(ts)
56                    .ok_or_else(|| anyhow!("timestamp out of bounds"))?
57                    .naive_utc(),
58            ))
59        }
60        (JsonValue::Number(n), SchemaPiece::TimestampMicro) => {
61            let ts = n.as_i64().unwrap();
62            Ok(Value::Timestamp(
63                chrono::DateTime::from_timestamp_micros(ts)
64                    .ok_or_else(|| anyhow!("timestamp out of bounds"))?
65                    .naive_utc(),
66            ))
67        }
68        (JsonValue::Array(items), SchemaPiece::Array(inner)) => Ok(Value::Array(
69            items
70                .iter()
71                .map(|x| from_json(x, schema.step(&**inner)))
72                .collect::<Result<_, _>>()?,
73        )),
74        (JsonValue::String(s), SchemaPiece::String) => Ok(Value::String(s.clone())),
75        (
76            JsonValue::Array(items),
77            SchemaPiece::Decimal {
78                precision, scale, ..
79            },
80        ) => {
81            let bytes = match items
82                .iter()
83                .map(|x| x.as_i64().and_then(|x| u8::try_from(x).ok()))
84                .collect::<Option<Vec<u8>>>()
85            {
86                Some(bytes) => bytes,
87                None => bail!("decimal was not represented by byte array"),
88            };
89            Ok(Value::Decimal(DecimalValue {
90                unscaled: bytes,
91                precision: *precision,
92                scale: *scale,
93            }))
94        }
95        (JsonValue::Array(items), SchemaPiece::Fixed { size }) => {
96            let bytes = match items
97                .iter()
98                .map(|x| x.as_i64().and_then(|x| u8::try_from(x).ok()))
99                .collect::<Option<Vec<u8>>>()
100            {
101                Some(bytes) => bytes,
102                None => bail!("fixed was not represented by byte array"),
103            };
104            if *size != bytes.len() {
105                bail!("expected fixed size {}, got {}", *size, bytes.len())
106            } else {
107                Ok(Value::Fixed(*size, bytes))
108            }
109        }
110        (JsonValue::String(s), SchemaPiece::Json) => {
111            let j = serde_json::from_str(s)?;
112            Ok(Value::Json(j))
113        }
114        (JsonValue::String(s), SchemaPiece::Uuid) => {
115            let u = uuid::Uuid::parse_str(s)?;
116            Ok(Value::Uuid(u))
117        }
118        (JsonValue::String(s), SchemaPiece::Enum { symbols, .. }) => {
119            if symbols.contains(s) {
120                Ok(Value::String(s.clone()))
121            } else {
122                bail!("Unrecognized enum variant: {}", s)
123            }
124        }
125        (JsonValue::Object(items), SchemaPiece::Record { .. }) => {
126            let mut builder = mz_avro::types::Record::new(schema)
127                .expect("`Record::new` should never fail if schema piece is a record!");
128            for (key, val) in items {
129                let field = builder
130                    .field_by_name(key)
131                    .ok_or_else(|| anyhow!("No such field {} in record: {}", key, val))?;
132                let val = from_json(val, schema.step(&field.schema))?;
133                builder.put(key, val);
134            }
135            Ok(builder.avro())
136        }
137        (JsonValue::Object(items), SchemaPiece::Map(m)) => {
138            let mut map = BTreeMap::new();
139            for (k, v) in items {
140                let (inner, name) = m.get_piece_and_name(schema.root);
141                map.insert(
142                    k.to_owned(),
143                    from_json(
144                        v,
145                        SchemaNode {
146                            root: schema.root,
147                            inner,
148                            name,
149                        },
150                    )?,
151                );
152            }
153            Ok(Value::Map(map))
154        }
155        (val, SchemaPiece::Union(us)) => {
156            let variants = us.variants();
157            let null_variant = variants
158                .iter()
159                .position(|v| v == &SchemaPieceOrNamed::Piece(SchemaPiece::Null));
160            if let JsonValue::Null = val {
161                return if let Some(nv) = null_variant {
162                    Ok(Value::Union {
163                        index: nv,
164                        inner: Box::new(Value::Null),
165                        n_variants: variants.len(),
166                        null_variant,
167                    })
168                } else {
169                    bail!("No `null` value in union schema.")
170                };
171            }
172            let items = match val {
173                JsonValue::Object(items) => items,
174                _ => bail!(
175                    "Union schema element must be `null` or a map from type name to value; found {:?}",
176                    val
177                ),
178            };
179            let (name, val) = if items.len() == 1 {
180                (items.keys().next().unwrap(), items.values().next().unwrap())
181            } else {
182                bail!(
183                    "Expected one-element object to match union schema: {:?} vs {:?}",
184                    json,
185                    schema
186                );
187            };
188            for (i, variant) in variants.iter().enumerate() {
189                let name_matches = match variant {
190                    SchemaPieceOrNamed::Piece(piece) => SchemaKind::from(piece).name() == name,
191                    SchemaPieceOrNamed::Named(idx) => {
192                        let schema_name = &schema.root.lookup(*idx).name;
193                        if name.chars().any(|ch| ch == '.') {
194                            name == &format!(
195                                "{}.{}",
196                                schema_name.namespace(),
197                                schema_name.base_name()
198                            )
199                        } else {
200                            name == schema_name.base_name()
201                        }
202                    }
203                };
204                if name_matches {
205                    match from_json(val, schema.step(variant)) {
206                        Ok(avro) => {
207                            return Ok(Value::Union {
208                                index: i,
209                                inner: Box::new(avro),
210                                n_variants: variants.len(),
211                                null_variant,
212                            });
213                        }
214                        Err(msg) => return Err(msg),
215                    }
216                }
217            }
218            bail!(
219                "Type not found in union: {}. variants: {:#?}",
220                name,
221                variants
222            )
223        }
224        _ => bail!(
225            "unable to match JSON value to schema: {:?} vs {:?}",
226            json,
227            schema
228        ),
229    }
230}
231
232/// Decodes an Avro datum from its Confluent-formatted byte representation.
233///
234/// The Confluent format includes a verbsion byte, followed by a 32-bit schema
235/// ID, followed by the encoded Avro value. This function validates the version
236/// byte but ignores the schema ID.
237pub fn from_confluent_bytes(schema: &Schema, mut bytes: &[u8]) -> Result<Value, anyhow::Error> {
238    if bytes.len() < 5 {
239        bail!(
240            "avro datum is too few bytes: expected at least 5 bytes, got {}",
241            bytes.len()
242        );
243    }
244    let magic = bytes[0];
245    let _schema_id = BigEndian::read_i32(&bytes[1..5]);
246    bytes = &bytes[5..];
247
248    if magic != 0 {
249        bail!(
250            "wrong avro serialization magic: expected 0, got {}",
251            bytes[0]
252        );
253    }
254
255    let datum = from_avro_datum(schema, &mut bytes).context("decoding avro datum")?;
256    Ok(datum)
257}
258
259/// A struct to enhance the debug output of various Avro types.
260///
261/// Testdrive scripts, for example, specify timestamps in micros, but debug
262/// output happens in Y-M-D format, which can be very difficult to map back to
263/// the correct input number. Similarly, dates are represented in Avro as
264/// `i32`s, but we would like to see the Y-M-D format as well.
265#[derive(Clone)]
266pub struct DebugValue(pub Value);
267
268impl Debug for DebugValue {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        match &self.0 {
271            Value::Timestamp(t) => write!(
272                f,
273                "Timestamp(\"{:?}\", {} micros, {} millis)",
274                t,
275                t.and_utc().timestamp_micros(),
276                t.and_utc().timestamp_millis()
277            ),
278            Value::Date(d) => write!(
279                f,
280                "Date({:?}, \"{}\")",
281                d,
282                NaiveDate::from_num_days_from_ce_opt(*d).unwrap()
283            ),
284
285            // Re-wrap types that contain a Value.
286            Value::Record(r) => f
287                .debug_set()
288                .entries(r.iter().map(|(s, v)| (s, DebugValue(v.clone()))))
289                .finish(),
290            Value::Array(a) => f
291                .debug_set()
292                .entries(a.iter().map(|v| DebugValue(v.clone())))
293                .finish(),
294            Value::Union {
295                index,
296                inner,
297                n_variants,
298                null_variant,
299            } => f
300                .debug_struct("Union")
301                .field("index", index)
302                .field("inner", &DebugValue(*inner.clone()))
303                .field("n_variants", n_variants)
304                .field("null_variant", null_variant)
305                .finish(),
306
307            _ => write!(f, "{:?}", self.0),
308        }
309    }
310}