// mz_interchange/protobuf.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
use std::collections::BTreeSet;

use anyhow::{Context, anyhow, bail};
use mz_ore::str::StrExt;
use mz_repr::{ColumnName, ColumnType, Datum, Row, RowPacker, ScalarType};
use prost_reflect::{
    Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor,
    ReflectMessage, Value,
};
19
/// A decoded description of the schema of a Protobuf message.
#[derive(Debug, PartialEq)]
pub struct DecodedDescriptors {
    // Descriptor for the message type that rows are decoded from.
    message_descriptor: MessageDescriptor,
    // Relational columns derived from the message's fields, in field order.
    columns: Vec<(ColumnName, ColumnType)>,
    // Fully qualified message name, as supplied to `from_bytes`.
    // NOTE(review): not read in this file — presumably kept for error
    // reporting elsewhere; confirm against callers.
    message_name: String,
}
27
28impl DecodedDescriptors {
29    /// Builds a `DecodedDescriptors` from an encoded `FileDescriptorSet` and
30    /// the fully qualified name of a message inside that file descriptor set.
31    pub fn from_bytes(bytes: &[u8], message_name: String) -> Result<Self, anyhow::Error> {
32        let fds = DescriptorPool::decode(bytes).context("decoding file descriptor set")?;
33        let message_descriptor = fds.get_message_by_name(&message_name).ok_or_else(|| {
34            anyhow!(
35                "protobuf message {} not found in file descriptor set",
36                message_name.quoted(),
37            )
38        })?;
39        let mut seen_messages = BTreeSet::new();
40        seen_messages.insert(message_descriptor.name().to_owned());
41        let mut columns = vec![];
42        for field in message_descriptor.fields() {
43            let name = ColumnName::from(field.name());
44            let ty = derive_column_type(&mut seen_messages, &field)?;
45            columns.push((name, ty))
46        }
47        Ok(DecodedDescriptors {
48            message_descriptor,
49            columns,
50            message_name,
51        })
52    }
53
54    /// Describes the columns in the message.
55    ///
56    /// In other words, the return value describes the shape of the rows that
57    /// will be produced by a [`Decoder`] constructed from this
58    /// `DecodedDescriptors`.
59    pub fn columns(&self) -> &[(ColumnName, ColumnType)] {
60        &self.columns
61    }
62}
63
/// Decodes a particular Protobuf message from its wire format.
#[derive(Debug)]
pub struct Decoder {
    // Schema of the message this decoder produces rows for.
    descriptors: DecodedDescriptors,
    // Scratch row reused across `decode` calls; cloned per decoded message.
    row: Row,
    // Whether incoming messages carry the Confluent wire-format header
    // (magic byte + schema ID) that must be stripped before decoding.
    confluent_wire_format: bool,
}
71
72impl Decoder {
73    /// Constructs a decoder for a particular Protobuf message.
74    pub fn new(
75        descriptors: DecodedDescriptors,
76        confluent_wire_format: bool,
77    ) -> Result<Self, anyhow::Error> {
78        Ok(Decoder {
79            descriptors,
80            row: Row::default(),
81            confluent_wire_format,
82        })
83    }
84
85    /// Decodes the encoded Protobuf message into a [`Row`].
86    pub fn decode(&mut self, mut bytes: &[u8]) -> Result<Option<Row>, anyhow::Error> {
87        if self.confluent_wire_format {
88            // We support Protobuf schema evolution by ignoring the schema that
89            // the message was written with and attempting to decode into the
90            // schema we know about. As long as the new schema has been evolved
91            // according to the Protobuf evolution rules [0], this produces
92            // sensible and desirable results.
93            //
94            // There is the possibility that the message has been written with
95            // an incompatible schema, but this is relatively unlikely as the
96            // schema registry enforces compatible evolution by default. We
97            // don't bother to perform our own compatibility checks because the
98            // rules are complex and the Protobuf format is self-describing
99            // enough that decoding an Protobuf message with an incompatible
100            // schema is handled gracefully (e.g., no accidentally massive
101            // allocations).
102            //
103            // [0]: https://developers.google.com/protocol-buffers/docs/overview
104            let (_schema_id, adjusted_bytes) = crate::confluent::extract_protobuf_header(bytes)?;
105            bytes = adjusted_bytes;
106        }
107        let message = DynamicMessage::decode(self.descriptors.message_descriptor.clone(), bytes)?;
108        let mut packer = self.row.packer();
109        pack_message(&mut packer, &message)?;
110        Ok(Some(self.row.clone()))
111    }
112}
113
114fn derive_column_type(
115    seen_messages: &mut BTreeSet<String>,
116    field: &FieldDescriptor,
117) -> Result<ColumnType, anyhow::Error> {
118    if field.is_map() {
119        bail!("Protobuf map fields are not supported");
120    }
121
122    let ty = derive_inner_type(seen_messages, field.kind())?;
123    if field.is_list() {
124        Ok(ColumnType {
125            nullable: false,
126            scalar_type: ScalarType::List {
127                element_type: Box::new(ty.scalar_type),
128                custom_id: None,
129            },
130        })
131    } else {
132        Ok(ty)
133    }
134}
135
136fn derive_inner_type(
137    seen_messages: &mut BTreeSet<String>,
138    ty: Kind,
139) -> Result<ColumnType, anyhow::Error> {
140    match ty {
141        Kind::Bool => Ok(ScalarType::Bool.nullable(false)),
142        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Ok(ScalarType::Int32.nullable(false)),
143        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Ok(ScalarType::Int64.nullable(false)),
144        Kind::Uint32 | Kind::Fixed32 => Ok(ScalarType::UInt32.nullable(false)),
145        Kind::Uint64 | Kind::Fixed64 => Ok(ScalarType::UInt64.nullable(false)),
146        Kind::Float => Ok(ScalarType::Float32.nullable(false)),
147        Kind::Double => Ok(ScalarType::Float64.nullable(false)),
148        Kind::String => Ok(ScalarType::String.nullable(false)),
149        Kind::Bytes => Ok(ScalarType::Bytes.nullable(false)),
150        Kind::Enum(_) => Ok(ScalarType::String.nullable(false)),
151        Kind::Message(m) => {
152            if seen_messages.contains(m.name()) {
153                bail!("Recursive types are not supported: {}", m.name());
154            }
155            seen_messages.insert(m.name().to_owned());
156            let mut fields = Vec::with_capacity(m.fields().len());
157            for field in m.fields() {
158                let column_name = ColumnName::from(field.name());
159                let column_type = derive_column_type(seen_messages, &field)?;
160                fields.push((column_name, column_type))
161            }
162            seen_messages.remove(m.name());
163            let ty = ScalarType::Record {
164                fields: fields.into(),
165                custom_id: None,
166            };
167            Ok(ty.nullable(true))
168        }
169    }
170}
171
172fn pack_message(packer: &mut RowPacker, message: &DynamicMessage) -> Result<(), anyhow::Error> {
173    for field_desc in message.descriptor().fields() {
174        if !message.has_field(&field_desc) {
175            if field_desc.cardinality() == Cardinality::Required {
176                bail!(
177                    "protobuf message missing required field {}",
178                    field_desc.name()
179                );
180            }
181            if field_desc.kind().as_message().is_some() && !field_desc.is_list() {
182                packer.push(Datum::Null);
183                continue;
184            }
185        }
186        let value = message.get_field(&field_desc);
187        pack_value(packer, &field_desc, &*value)?;
188    }
189    Ok(())
190}
191
192fn pack_value(
193    packer: &mut RowPacker,
194    field_desc: &FieldDescriptor,
195    value: &Value,
196) -> Result<(), anyhow::Error> {
197    match value {
198        Value::Bool(false) => packer.push(Datum::False),
199        Value::Bool(true) => packer.push(Datum::True),
200        Value::I32(i) => packer.push(Datum::Int32(*i)),
201        Value::I64(i) => packer.push(Datum::Int64(*i)),
202        Value::U32(i) => packer.push(Datum::UInt32(*i)),
203        Value::U64(i) => packer.push(Datum::UInt64(*i)),
204        Value::F32(f) => packer.push(Datum::Float32((*f).into())),
205        Value::F64(f) => packer.push(Datum::Float64((*f).into())),
206        Value::String(s) => packer.push(Datum::String(s)),
207        Value::Bytes(s) => packer.push(Datum::Bytes(s)),
208        Value::EnumNumber(i) => {
209            let kind = field_desc.kind();
210            let enum_desc = kind.as_enum().ok_or_else(|| {
211                anyhow!(
212                    "internal error: decoding protobuf: field {} missing enum descriptor",
213                    field_desc.name()
214                )
215            })?;
216            let value = enum_desc.get_value(*i).ok_or_else(|| {
217                anyhow!(
218                    "error decoding protobuf: unknown enum value {} while decoding field {}",
219                    i,
220                    field_desc.name()
221                )
222            })?;
223            packer.push(Datum::String(value.name()));
224        }
225        Value::Message(m) => packer.push_list_with(|packer| pack_message(packer, m))?,
226        Value::List(values) => {
227            packer.push_list_with(|packer| {
228                for value in values {
229                    pack_value(packer, field_desc, value)?;
230                }
231                Ok::<_, anyhow::Error>(())
232            })?;
233        }
234        Value::Map(_) => bail!(
235            "internal error: unexpected value while decoding protobuf message: {:?}",
236            value
237        ),
238    }
239    Ok(())
240}