1use std::collections::BTreeSet;
11
12use anyhow::{Context, anyhow, bail};
13use mz_ore::str::StrExt;
14use mz_repr::{ColumnName, ColumnType, Datum, Row, RowPacker, ScalarType};
15use prost_reflect::{
16 Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor,
17 ReflectMessage, Value,
18};
19
20#[derive(Debug, PartialEq)]
22pub struct DecodedDescriptors {
23 message_descriptor: MessageDescriptor,
24 columns: Vec<(ColumnName, ColumnType)>,
25 message_name: String,
26}
27
28impl DecodedDescriptors {
29 pub fn from_bytes(bytes: &[u8], message_name: String) -> Result<Self, anyhow::Error> {
32 let fds = DescriptorPool::decode(bytes).context("decoding file descriptor set")?;
33 let message_descriptor = fds.get_message_by_name(&message_name).ok_or_else(|| {
34 anyhow!(
35 "protobuf message {} not found in file descriptor set",
36 message_name.quoted(),
37 )
38 })?;
39 let mut seen_messages = BTreeSet::new();
40 seen_messages.insert(message_descriptor.name().to_owned());
41 let mut columns = vec![];
42 for field in message_descriptor.fields() {
43 let name = ColumnName::from(field.name());
44 let ty = derive_column_type(&mut seen_messages, &field)?;
45 columns.push((name, ty))
46 }
47 Ok(DecodedDescriptors {
48 message_descriptor,
49 columns,
50 message_name,
51 })
52 }
53
54 pub fn columns(&self) -> &[(ColumnName, ColumnType)] {
60 &self.columns
61 }
62}
63
64#[derive(Debug)]
66pub struct Decoder {
67 descriptors: DecodedDescriptors,
68 row: Row,
69 confluent_wire_format: bool,
70}
71
72impl Decoder {
73 pub fn new(
75 descriptors: DecodedDescriptors,
76 confluent_wire_format: bool,
77 ) -> Result<Self, anyhow::Error> {
78 Ok(Decoder {
79 descriptors,
80 row: Row::default(),
81 confluent_wire_format,
82 })
83 }
84
85 pub fn decode(&mut self, mut bytes: &[u8]) -> Result<Option<Row>, anyhow::Error> {
87 if self.confluent_wire_format {
88 let (_schema_id, adjusted_bytes) = crate::confluent::extract_protobuf_header(bytes)?;
105 bytes = adjusted_bytes;
106 }
107 let message = DynamicMessage::decode(self.descriptors.message_descriptor.clone(), bytes)?;
108 let mut packer = self.row.packer();
109 pack_message(&mut packer, &message)?;
110 Ok(Some(self.row.clone()))
111 }
112}
113
114fn derive_column_type(
115 seen_messages: &mut BTreeSet<String>,
116 field: &FieldDescriptor,
117) -> Result<ColumnType, anyhow::Error> {
118 if field.is_map() {
119 bail!("Protobuf map fields are not supported");
120 }
121
122 let ty = derive_inner_type(seen_messages, field.kind())?;
123 if field.is_list() {
124 Ok(ColumnType {
125 nullable: false,
126 scalar_type: ScalarType::List {
127 element_type: Box::new(ty.scalar_type),
128 custom_id: None,
129 },
130 })
131 } else {
132 Ok(ty)
133 }
134}
135
136fn derive_inner_type(
137 seen_messages: &mut BTreeSet<String>,
138 ty: Kind,
139) -> Result<ColumnType, anyhow::Error> {
140 match ty {
141 Kind::Bool => Ok(ScalarType::Bool.nullable(false)),
142 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => Ok(ScalarType::Int32.nullable(false)),
143 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => Ok(ScalarType::Int64.nullable(false)),
144 Kind::Uint32 | Kind::Fixed32 => Ok(ScalarType::UInt32.nullable(false)),
145 Kind::Uint64 | Kind::Fixed64 => Ok(ScalarType::UInt64.nullable(false)),
146 Kind::Float => Ok(ScalarType::Float32.nullable(false)),
147 Kind::Double => Ok(ScalarType::Float64.nullable(false)),
148 Kind::String => Ok(ScalarType::String.nullable(false)),
149 Kind::Bytes => Ok(ScalarType::Bytes.nullable(false)),
150 Kind::Enum(_) => Ok(ScalarType::String.nullable(false)),
151 Kind::Message(m) => {
152 if seen_messages.contains(m.name()) {
153 bail!("Recursive types are not supported: {}", m.name());
154 }
155 seen_messages.insert(m.name().to_owned());
156 let mut fields = Vec::with_capacity(m.fields().len());
157 for field in m.fields() {
158 let column_name = ColumnName::from(field.name());
159 let column_type = derive_column_type(seen_messages, &field)?;
160 fields.push((column_name, column_type))
161 }
162 seen_messages.remove(m.name());
163 let ty = ScalarType::Record {
164 fields: fields.into(),
165 custom_id: None,
166 };
167 Ok(ty.nullable(true))
168 }
169 }
170}
171
172fn pack_message(packer: &mut RowPacker, message: &DynamicMessage) -> Result<(), anyhow::Error> {
173 for field_desc in message.descriptor().fields() {
174 if !message.has_field(&field_desc) {
175 if field_desc.cardinality() == Cardinality::Required {
176 bail!(
177 "protobuf message missing required field {}",
178 field_desc.name()
179 );
180 }
181 if field_desc.kind().as_message().is_some() && !field_desc.is_list() {
182 packer.push(Datum::Null);
183 continue;
184 }
185 }
186 let value = message.get_field(&field_desc);
187 pack_value(packer, &field_desc, &*value)?;
188 }
189 Ok(())
190}
191
192fn pack_value(
193 packer: &mut RowPacker,
194 field_desc: &FieldDescriptor,
195 value: &Value,
196) -> Result<(), anyhow::Error> {
197 match value {
198 Value::Bool(false) => packer.push(Datum::False),
199 Value::Bool(true) => packer.push(Datum::True),
200 Value::I32(i) => packer.push(Datum::Int32(*i)),
201 Value::I64(i) => packer.push(Datum::Int64(*i)),
202 Value::U32(i) => packer.push(Datum::UInt32(*i)),
203 Value::U64(i) => packer.push(Datum::UInt64(*i)),
204 Value::F32(f) => packer.push(Datum::Float32((*f).into())),
205 Value::F64(f) => packer.push(Datum::Float64((*f).into())),
206 Value::String(s) => packer.push(Datum::String(s)),
207 Value::Bytes(s) => packer.push(Datum::Bytes(s)),
208 Value::EnumNumber(i) => {
209 let kind = field_desc.kind();
210 let enum_desc = kind.as_enum().ok_or_else(|| {
211 anyhow!(
212 "internal error: decoding protobuf: field {} missing enum descriptor",
213 field_desc.name()
214 )
215 })?;
216 let value = enum_desc.get_value(*i).ok_or_else(|| {
217 anyhow!(
218 "error decoding protobuf: unknown enum value {} while decoding field {}",
219 i,
220 field_desc.name()
221 )
222 })?;
223 packer.push(Datum::String(value.name()));
224 }
225 Value::Message(m) => packer.push_list_with(|packer| pack_message(packer, m))?,
226 Value::List(values) => {
227 packer.push_list_with(|packer| {
228 for value in values {
229 pack_value(packer, field_desc, value)?;
230 }
231 Ok::<_, anyhow::Error>(())
232 })?;
233 }
234 Value::Map(_) => bail!(
235 "internal error: unexpected value while decoding protobuf message: {:?}",
236 value
237 ),
238 }
239 Ok(())
240}