mz_persist_types/
arrow.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! A [protobuf] representation of [Apache Arrow] arrays.
//!
//! # Motivation
//!
//! Persist can store a small amount of data inline at the consensus layer.
//! Because we are space constrained, we take particular care to store only the
//! data that is necessary. Other Arrow serialization formats, e.g. [Parquet]
//! or [Arrow IPC], include data that we don't need and would be wasteful to
//! store.
//!
//! [protobuf]: https://protobuf.dev/
//! [Apache Arrow]: https://arrow.apache.org/
//! [Parquet]: https://parquet.apache.org/docs/
//! [Arrow IPC]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
24
25use std::cmp::Ordering;
26use std::fmt::{Debug, Display};
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
// Generated protobuf types for Arrow data. The `missing_docs` lint is silenced
// for this module, presumably because the generated code carries no doc
// comments of its own.
#[allow(missing_docs)]
mod proto {
    // Textually includes the `mz_persist_types.arrow.rs` file found in the
    // build's `OUT_DIR`.
    include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45/// Extract the list of fields for our recursive datatypes.
46pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47    match data_type {
48        DataType::Struct(fields) => fields,
49        DataType::List(field) => std::slice::from_ref(field),
50        DataType::Map(field, _) => std::slice::from_ref(field),
51        DataType::Null
52        | DataType::Boolean
53        | DataType::Int8
54        | DataType::Int16
55        | DataType::Int32
56        | DataType::Int64
57        | DataType::UInt8
58        | DataType::UInt16
59        | DataType::UInt32
60        | DataType::UInt64
61        | DataType::Float16
62        | DataType::Float32
63        | DataType::Float64
64        | DataType::Timestamp(_, _)
65        | DataType::Date32
66        | DataType::Date64
67        | DataType::Time32(_)
68        | DataType::Time64(_)
69        | DataType::Duration(_)
70        | DataType::Interval(_)
71        | DataType::Binary
72        | DataType::FixedSizeBinary(_)
73        | DataType::LargeBinary
74        | DataType::BinaryView
75        | DataType::Utf8
76        | DataType::LargeUtf8
77        | DataType::Utf8View
78        | DataType::Decimal32(_, _)
79        | DataType::Decimal64(_, _)
80        | DataType::Decimal128(_, _)
81        | DataType::Decimal256(_, _) => &[],
82        DataType::ListView(_)
83        | DataType::FixedSizeList(_, _)
84        | DataType::LargeList(_)
85        | DataType::LargeListView(_)
86        | DataType::Union(_, _)
87        | DataType::Dictionary(_, _)
88        | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
89    }
90}
91
92/// Encode the array into proto. If an expected data type is passed, that implies it is
93/// encoded at some higher level, and we omit it from the data.
94fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
95    let data_type = match expected_type {
96        Some(expected) => {
97            // Equality is recursive, and this function is itself called recursively,
98            // skip the call in production to avoid a quadratic overhead.
99            soft_assert_eq_no_log!(
100                expected,
101                data.data_type(),
102                "actual type should match expected type"
103            );
104            None
105        }
106        None => Some(data.data_type().into_proto()),
107    };
108
109    ProtoArrayData {
110        data_type,
111        length: u64::cast_from(data.len()),
112        offset: u64::cast_from(data.offset()),
113        buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
114        children: data
115            .child_data()
116            .iter()
117            .zip_eq(fields_for_type(
118                expected_type.unwrap_or_else(|| data.data_type()),
119            ))
120            .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
121            .collect(),
122        nulls: data.nulls().map(|n| n.inner().into_proto()),
123    }
124}
125
126/// Decode the array data.
127/// If the data type is omitted from the proto, we decode it as the expected type.
128fn from_proto_with_type(
129    proto: ProtoArrayData,
130    expected_type: Option<&DataType>,
131) -> Result<ArrayData, TryFromProtoError> {
132    let ProtoArrayData {
133        data_type,
134        length,
135        offset,
136        buffers,
137        children,
138        nulls,
139    } = proto;
140    let data_type: Option<DataType> = data_type.into_rust()?;
141    let data_type = match (data_type, expected_type) {
142        (Some(data_type), None) => data_type,
143        (Some(data_type), Some(expected_type)) => {
144            // Equality is recursive, and this function is itself called recursively,
145            // skip the call in production to avoid a quadratic overhead.
146            soft_assert_eq_no_log!(
147                data_type,
148                *expected_type,
149                "expected type should match actual type"
150            );
151            data_type
152        }
153        (None, Some(expected_type)) => expected_type.clone(),
154        (None, None) => {
155            return Err(TryFromProtoError::MissingField(
156                "ProtoArrayData::data_type".to_string(),
157            ));
158        }
159    };
160    let nulls = nulls
161        .map(|n| n.into_rust())
162        .transpose()?
163        .map(NullBuffer::new);
164
165    let mut builder = ArrayDataBuilder::new(data_type.clone())
166        .len(usize::cast_from(length))
167        .offset(usize::cast_from(offset))
168        .nulls(nulls);
169
170    for b in buffers.into_iter().map(|b| b.into_rust()) {
171        builder = builder.add_buffer(b?);
172    }
173    for c in children
174        .into_iter()
175        .zip_eq(fields_for_type(&data_type))
176        .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
177    {
178        builder = builder.add_child_data(c?);
179    }
180
181    // Construct the builder which validates all inputs and aligns data.
182    builder
183        .align_buffers(true)
184        .build()
185        .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
186}
187
188impl RustType<ProtoArrayData> for arrow::array::ArrayData {
189    fn into_proto(&self) -> ProtoArrayData {
190        into_proto_with_type(self, None)
191    }
192
193    fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
194        from_proto_with_type(proto, None)
195    }
196}
197
198impl RustType<proto::DataType> for arrow::datatypes::DataType {
199    fn into_proto(&self) -> proto::DataType {
200        let kind = match self {
201            DataType::Null => proto::data_type::Kind::Null(()),
202            DataType::Boolean => proto::data_type::Kind::Boolean(()),
203            DataType::UInt8 => proto::data_type::Kind::Uint8(()),
204            DataType::UInt16 => proto::data_type::Kind::Uint16(()),
205            DataType::UInt32 => proto::data_type::Kind::Uint32(()),
206            DataType::UInt64 => proto::data_type::Kind::Uint64(()),
207            DataType::Int8 => proto::data_type::Kind::Int8(()),
208            DataType::Int16 => proto::data_type::Kind::Int16(()),
209            DataType::Int32 => proto::data_type::Kind::Int32(()),
210            DataType::Int64 => proto::data_type::Kind::Int64(()),
211            DataType::Float32 => proto::data_type::Kind::Float32(()),
212            DataType::Float64 => proto::data_type::Kind::Float64(()),
213            DataType::Utf8 => proto::data_type::Kind::String(()),
214            DataType::Binary => proto::data_type::Kind::Binary(()),
215            DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
216            DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
217            DataType::Map(inner, sorted) => {
218                let map = proto::data_type::Map {
219                    value: Some(Box::new(inner.into_proto())),
220                    sorted: *sorted,
221                };
222                proto::data_type::Kind::Map(Box::new(map))
223            }
224            DataType::Struct(children) => {
225                let children = children.into_iter().map(|f| f.into_proto()).collect();
226                proto::data_type::Kind::Struct(proto::data_type::Struct { children })
227            }
228            other => unimplemented!("unsupported data type {other:?}"),
229        };
230
231        proto::DataType { kind: Some(kind) }
232    }
233
234    fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
235        let data_type = proto
236            .kind
237            .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
238        let data_type = match data_type {
239            proto::data_type::Kind::Null(()) => DataType::Null,
240            proto::data_type::Kind::Boolean(()) => DataType::Boolean,
241            proto::data_type::Kind::Uint8(()) => DataType::UInt8,
242            proto::data_type::Kind::Uint16(()) => DataType::UInt16,
243            proto::data_type::Kind::Uint32(()) => DataType::UInt32,
244            proto::data_type::Kind::Uint64(()) => DataType::UInt64,
245            proto::data_type::Kind::Int8(()) => DataType::Int8,
246            proto::data_type::Kind::Int16(()) => DataType::Int16,
247            proto::data_type::Kind::Int32(()) => DataType::Int32,
248            proto::data_type::Kind::Int64(()) => DataType::Int64,
249            proto::data_type::Kind::Float32(()) => DataType::Float32,
250            proto::data_type::Kind::Float64(()) => DataType::Float64,
251            proto::data_type::Kind::String(()) => DataType::Utf8,
252            proto::data_type::Kind::Binary(()) => DataType::Binary,
253            proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
254            proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
255            proto::data_type::Kind::Map(inner) => {
256                let value = inner
257                    .value
258                    .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
259                DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
260            }
261            proto::data_type::Kind::Struct(inner) => {
262                let children: Vec<Field> = inner
263                    .children
264                    .into_iter()
265                    .map(|c| c.into_rust())
266                    .collect::<Result<_, _>>()?;
267                DataType::Struct(Fields::from(children))
268            }
269        };
270
271        Ok(data_type)
272    }
273}
274
275impl RustType<proto::Field> for arrow::datatypes::Field {
276    fn into_proto(&self) -> proto::Field {
277        proto::Field {
278            name: self.name().clone(),
279            nullable: self.is_nullable(),
280            data_type: Some(Box::new(self.data_type().into_proto())),
281        }
282    }
283
284    fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
285        let proto::Field {
286            name,
287            nullable,
288            data_type,
289        } = proto;
290        let data_type =
291            data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
292        let data_type = (*data_type).into_rust()?;
293
294        Ok(Field::new(name, data_type, nullable))
295    }
296}
297
298impl RustType<proto::Buffer> for arrow::buffer::Buffer {
299    fn into_proto(&self) -> proto::Buffer {
300        // Wrapping since arrow's buffer doesn't implement AsRef, though the deref impl exists.
301        #[repr(transparent)]
302        struct BufferWrapper(arrow::buffer::Buffer);
303        impl AsRef<[u8]> for BufferWrapper {
304            fn as_ref(&self) -> &[u8] {
305                &*self.0
306            }
307        }
308        proto::Buffer {
309            data: bytes::Bytes::from_owner(BufferWrapper(self.clone())),
310        }
311    }
312
313    fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
314        Ok(arrow::buffer::Buffer::from(proto.data))
315    }
316}
317
318impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
319    fn into_proto(&self) -> proto::BooleanBuffer {
320        proto::BooleanBuffer {
321            buffer: Some(self.sliced().into_proto()),
322            length: u64::cast_from(self.len()),
323        }
324    }
325
326    fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
327        let proto::BooleanBuffer { buffer, length } = proto;
328        let buffer = buffer.into_rust_if_some("buffer")?;
329        Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
330    }
331}
332
/// Wraps a single arrow array, downcasted to a specific type.
///
/// Primitive, string and binary variants hold the typed arrow array directly.
/// The nested variants (`List`, `Struct`) are decomposed so that their children
/// are themselves `ArrayOrd`s.
#[derive(Clone)]
pub enum ArrayOrd {
    /// Wraps a `NullArray`.
    Null(NullArray),
    /// Wraps a `Bool` array.
    Bool(BooleanArray),
    /// Wraps a `Int8` array.
    Int8(Int8Array),
    /// Wraps a `Int16` array.
    Int16(Int16Array),
    /// Wraps a `Int32` array.
    Int32(Int32Array),
    /// Wraps a `Int64` array.
    Int64(Int64Array),
    /// Wraps a `UInt8` array.
    UInt8(UInt8Array),
    /// Wraps a `UInt16` array.
    UInt16(UInt16Array),
    /// Wraps a `UInt32` array.
    UInt32(UInt32Array),
    /// Wraps a `UInt64` array.
    UInt64(UInt64Array),
    /// Wraps a `Float32` array.
    Float32(Float32Array),
    /// Wraps a `Float64` array.
    Float64(Float64Array),
    /// Wraps a `String` array.
    String(StringArray),
    /// Wraps a `Binary` array.
    Binary(BinaryArray),
    /// Wraps a `FixedSizeBinary` array.
    FixedSizeBinary(FixedSizeBinaryArray),
    /// Wraps a `List` array: the list's null buffer (if any), its offsets into the
    /// values array, and the wrapped values array.
    List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
    /// Wraps a `Struct` array: the struct's null buffer (if any) and one wrapped
    /// array per column.
    Struct(Option<NullBuffer>, Vec<ArrayOrd>),
}
371
372impl ArrayOrd {
373    /// Downcast the provided array to a specific type in our enum.
374    pub fn new(array: &dyn Array) -> Self {
375        match array.data_type() {
376            DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
377            DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
378            DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
379            DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
380            DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
381            DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
382            DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
383            DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
384            DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
385            DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
386            DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
387            DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
388            DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
389            DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
390            DataType::FixedSizeBinary(_) => {
391                ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
392            }
393            DataType::List(_) => {
394                let list_array = array.as_list();
395                ArrayOrd::List(
396                    list_array.nulls().cloned(),
397                    list_array.offsets().clone(),
398                    Box::new(ArrayOrd::new(list_array.values())),
399                )
400            }
401            DataType::Struct(_) => {
402                let struct_array = array.as_struct();
403                let nulls = array.nulls().cloned();
404                let columns: Vec<_> = struct_array
405                    .columns()
406                    .iter()
407                    .map(|a| ArrayOrd::new(a))
408                    .collect();
409                ArrayOrd::Struct(nulls, columns)
410            }
411            data_type => unimplemented!("array type {data_type:?} not yet supported"),
412        }
413    }
414
415    /// Returns the rough amount of space required for the data in this array in bytes.
416    /// (Not counting nulls, dictionary encoding, or other space optimizations.)
417    pub fn goodbytes(&self) -> usize {
418        match self {
419            ArrayOrd::Null(_) => 0,
420            // This is, strictly speaking, wrong - but consistent with `ArrayIdx::goodbytes`,
421            // which counts one byte per bool.
422            ArrayOrd::Bool(b) => b.len(),
423            ArrayOrd::Int8(a) => a.values().inner().len(),
424            ArrayOrd::Int16(a) => a.values().inner().len(),
425            ArrayOrd::Int32(a) => a.values().inner().len(),
426            ArrayOrd::Int64(a) => a.values().inner().len(),
427            ArrayOrd::UInt8(a) => a.values().inner().len(),
428            ArrayOrd::UInt16(a) => a.values().inner().len(),
429            ArrayOrd::UInt32(a) => a.values().inner().len(),
430            ArrayOrd::UInt64(a) => a.values().inner().len(),
431            ArrayOrd::Float32(a) => a.values().inner().len(),
432            ArrayOrd::Float64(a) => a.values().inner().len(),
433            ArrayOrd::String(a) => a.values().len(),
434            ArrayOrd::Binary(a) => a.values().len(),
435            ArrayOrd::FixedSizeBinary(a) => a.values().len(),
436            ArrayOrd::List(_, _, nested) => nested.goodbytes(),
437            ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
438        }
439    }
440
441    /// Return a struct representing the value at a particular index in this array.
442    pub fn at(&self, idx: usize) -> ArrayIdx<'_> {
443        ArrayIdx { idx, array: self }
444    }
445}
446
447impl Debug for ArrayOrd {
448    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        struct DebugType<'a>(&'a ArrayOrd);
450
451        impl Debug for DebugType<'_> {
452            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453                match self.0 {
454                    ArrayOrd::Null(_) => write!(f, "Null"),
455                    ArrayOrd::Bool(_) => write!(f, "Bool"),
456                    ArrayOrd::Int8(_) => write!(f, "Int8"),
457                    ArrayOrd::Int16(_) => write!(f, "Int16"),
458                    ArrayOrd::Int32(_) => write!(f, "Int32"),
459                    ArrayOrd::Int64(_) => write!(f, "Int64"),
460                    ArrayOrd::UInt8(_) => write!(f, "UInt8"),
461                    ArrayOrd::UInt16(_) => write!(f, "UInt16"),
462                    ArrayOrd::UInt32(_) => write!(f, "UInt32"),
463                    ArrayOrd::UInt64(_) => write!(f, "UInt64"),
464                    ArrayOrd::Float32(_) => write!(f, "Float32"),
465                    ArrayOrd::Float64(_) => write!(f, "Float64"),
466                    ArrayOrd::String(_) => write!(f, "String"),
467                    ArrayOrd::Binary(_) => write!(f, "Binary"),
468                    ArrayOrd::FixedSizeBinary(a) => f
469                        .debug_tuple("FixedSizeBinary")
470                        .field(&a.value_length())
471                        .finish(),
472                    ArrayOrd::List(_, _, nested) => f.debug_tuple("List").field(&*nested).finish(),
473                    ArrayOrd::Struct(_, fields) => {
474                        let mut tuple = f.debug_tuple("Struct");
475                        for field in fields {
476                            tuple.field(field);
477                        }
478                        tuple.finish()
479                    }
480                }
481            }
482        }
483
484        f.debug_struct("ArrayOrd")
485            .field("type", &DebugType(self))
486            .field("goodbytes", &self.goodbytes())
487            .finish()
488    }
489}
490
/// A struct representing a particular entry in a particular array. Most useful for its `Ord`
/// implementation, which can compare entire rows across similarly-typed arrays.
///
/// It is an error to compare indices from arrays with different types, with one exception:
/// it is valid to compare two `StructArray`s, one of which is a prefix of the other...
/// in which case we'll compare the values on that subset of the fields, and the shorter
/// of the two structs will compare less if they're otherwise equal.
///
/// The struct only borrows the array, so it is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct ArrayIdx<'a> {
    /// An index into a particular array.
    pub idx: usize,
    /// The particular array.
    pub array: &'a ArrayOrd,
}
505
506impl Display for ArrayIdx<'_> {
507    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508        match self.array {
509            ArrayOrd::Null(_) => write!(f, "null"),
510            ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
511            ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
512            ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
513            ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
514            ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
515            ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
516            ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
517            ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
518            ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
519            ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
520            ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
521            ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
522            ArrayOrd::Binary(a) => {
523                for byte in a.value(self.idx) {
524                    write!(f, "{:02x}", byte)?;
525                }
526                Ok(())
527            }
528            ArrayOrd::FixedSizeBinary(a) => {
529                for byte in a.value(self.idx) {
530                    write!(f, "{:02x}", byte)?;
531                }
532                Ok(())
533            }
534            ArrayOrd::List(_, offsets, nested) => {
535                write!(
536                    f,
537                    "[{}]",
538                    mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
539                )
540            }
541            ArrayOrd::Struct(_, nested) => write!(
542                f,
543                "{{{}}}",
544                mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
545            ),
546        }
547    }
548}
549
550#[inline]
551fn list_range<'a>(
552    offsets: &OffsetBuffer<i32>,
553    values: &'a ArrayOrd,
554    idx: usize,
555) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
556    let offsets = offsets.inner();
557    let from = offsets[idx].as_usize();
558    let to = offsets[idx + 1].as_usize();
559    (from..to).map(|i| values.at(i))
560}
561
562impl<'a> ArrayIdx<'a> {
563    /// Returns the rough amount of space required for this entry in bytes.
564    /// (Not counting nulls, dictionary encoding, or other space optimizations.)
565    pub fn goodbytes(&self) -> usize {
566        match self.array {
567            ArrayOrd::Null(_) => 0,
568            ArrayOrd::Bool(_) => size_of::<bool>(),
569            ArrayOrd::Int8(_) => size_of::<i8>(),
570            ArrayOrd::Int16(_) => size_of::<i16>(),
571            ArrayOrd::Int32(_) => size_of::<i32>(),
572            ArrayOrd::Int64(_) => size_of::<i64>(),
573            ArrayOrd::UInt8(_) => size_of::<u8>(),
574            ArrayOrd::UInt16(_) => size_of::<u16>(),
575            ArrayOrd::UInt32(_) => size_of::<u32>(),
576            ArrayOrd::UInt64(_) => size_of::<u64>(),
577            ArrayOrd::Float32(_) => size_of::<f32>(),
578            ArrayOrd::Float64(_) => size_of::<f64>(),
579            ArrayOrd::String(a) => a.value(self.idx).len(),
580            ArrayOrd::Binary(a) => a.value(self.idx).len(),
581            ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
582            ArrayOrd::List(_, offsets, nested) => {
583                // Range over the list, summing up the bytes for each entry.
584                list_range(offsets, nested, self.idx)
585                    .map(|a| a.goodbytes())
586                    .sum()
587            }
588            ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
589        }
590    }
591}
592
593impl<'a> Ord for ArrayIdx<'a> {
594    fn cmp(&self, other: &Self) -> Ordering {
595        #[inline]
596        fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
597            buffer.as_ref().map_or(false, |b| b.is_null(idx))
598        }
599        #[inline]
600        fn cmp<A: ArrayAccessor>(
601            left: A,
602            left_idx: usize,
603            right: A,
604            right_idx: usize,
605            cmp: fn(&A::Item, &A::Item) -> Ordering,
606        ) -> Ordering {
607            // NB: nulls sort last, conveniently matching psql / mz_repr
608            match (left.is_null(left_idx), right.is_null(right_idx)) {
609                (false, true) => Ordering::Less,
610                (true, true) => Ordering::Equal,
611                (true, false) => Ordering::Greater,
612                (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
613            }
614        }
615        match (&self.array, &other.array) {
616            (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
617                debug_assert!(
618                    self.idx < s.len() && other.idx < o.len(),
619                    "null array indices in bounds"
620                );
621                Ordering::Equal
622            }
623            // For arrays with "simple" value types, we fetch and compare the underlying values directly.
624            (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
625            (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
626            (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
627            (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
628            (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
629            (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
630            (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
631            (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
632            (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
633            (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
634                cmp(s, self.idx, o, other.idx, f32::total_cmp)
635            }
636            (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
637                cmp(s, self.idx, o, other.idx, f64::total_cmp)
638            }
639            (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
640            (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
641            (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
642                cmp(s, self.idx, o, other.idx, Ord::cmp)
643            }
644            // For lists, we generate an iterator for each side that ranges over the correct
645            // indices into the value buffer, then compare them lexicographically.
646            (
647                ArrayOrd::List(s_nulls, s_offset, s_values),
648                ArrayOrd::List(o_nulls, o_offset, o_values),
649            ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
650                (false, true) => Ordering::Less,
651                (true, true) => Ordering::Equal,
652                (true, false) => Ordering::Greater,
653                (false, false) => list_range(s_offset, s_values, self.idx)
654                    .cmp(list_range(o_offset, o_values, other.idx)),
655            },
656            // For structs, we iterate over the same index in each field for each input,
657            // comparing them lexicographically in order.
658            (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
659                match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
660                    (false, true) => Ordering::Less,
661                    (true, true) => Ordering::Equal,
662                    (true, false) => Ordering::Greater,
663                    (false, false) => {
664                        let s = s_cols.iter().map(|array| array.at(self.idx));
665                        let o = o_cols.iter().map(|array| array.at(other.idx));
666                        s.cmp(o)
667                    }
668                }
669            }
670            (a, b) => panic!("array types did not match! {a:?} vs. {b:?}",),
671        }
672    }
673}
674
675impl<'a> PartialOrd for ArrayIdx<'a> {
676    fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
677        Some(self.cmp(other))
678    }
679}
680
681impl<'a> PartialEq for ArrayIdx<'a> {
682    fn eq(&self, other: &ArrayIdx) -> bool {
683        self.cmp(other) == Ordering::Equal
684    }
685}
686
687impl<'a> Eq for ArrayIdx<'a> {}
688
/// An array with precisely one entry, for use as a lower bound.
#[derive(Debug, Clone)]
pub struct ArrayBound {
    // The arrow array the bound was created from.
    raw: ArrayRef,
    // The same array, downcast into `ArrayOrd` form so it can be compared.
    ord: ArrayOrd,
    // Position within the array of the entry that acts as the bound.
    index: usize,
}
696
697impl PartialEq for ArrayBound {
698    fn eq(&self, other: &Self) -> bool {
699        self.get().eq(&other.get())
700    }
701}
702
703impl Eq for ArrayBound {}
704
705impl ArrayBound {
706    /// Create a new `ArrayBound` for this array, with the bound at the provided index.
707    pub fn new(array: ArrayRef, index: usize) -> Self {
708        Self {
709            ord: ArrayOrd::new(array.as_ref()),
710            raw: array,
711            index,
712        }
713    }
714
715    /// Get the value of the bound.
716    pub fn get(&self) -> ArrayIdx<'_> {
717        self.ord.at(self.index)
718    }
719
720    /// Convert to an array-data proto, respecting a maximum data size. The resulting proto will
721    /// decode to a single-row array, such that `ArrayBound::new(decoded, 0).get() <= self.get()`,
722    /// which makes it suitable as a lower bound.
723    pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
724        // Use `take` instead of slice to make sure we encode just the relevant row to proto,
725        // instead of some larger buffer with an offset.
726        let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
727        let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
728        let array_data = taken.into_data();
729
730        let mut proto = array_data.into_proto();
731        let original_len = proto.encoded_len();
732        if original_len <= max_len {
733            return Some(proto);
734        }
735
736        let mut data_type = proto.data_type.take()?;
737        maybe_trim_proto(&mut data_type, &mut proto, max_len);
738        proto.data_type = Some(data_type);
739
740        if cfg!(debug_assertions) {
741            let array: ArrayData = proto
742                .clone()
743                .into_rust()
744                .expect("trimmed array data can still be decoded");
745            assert_eq!(array.len(), 1);
746            let new_bound = Self::new(make_array(array), 0);
747            assert!(
748                new_bound.get() <= self.get(),
749                "trimmed bound should be comparable to and no larger than the original data"
750            )
751        }
752
753        if proto.encoded_len() <= max_len {
754            Some(proto)
755        } else {
756            None
757        }
758    }
759}
760
761/// Makes a best effort to shrink the proto while preserving the ordering.
762/// (The proto might not be smaller after this method is called, but it should always
763/// be a valid lower bound.)
764///
765/// Note that we pass in the data type and the array data separately, since we only keep
766/// type info at the top level. If a caller does have a top-level `ArrayData` instance,
767/// they should take that type and pass it in separately.
768fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
769    assert!(body.data_type.is_none(), "expected separate data type");
770    // TODO: consider adding cases for strings and byte arrays
771    let encoded_len = data_type.encoded_len() + body.encoded_len();
772    match &mut data_type.kind {
773        Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
774            // Pop off fields one by one, keeping an estimate of the encoded length.
775            let mut struct_len = encoded_len;
776            while struct_len > max_len {
777                let Some(mut child) = body.children.pop() else {
778                    break;
779                };
780                let Some(mut field) = fields.pop() else { break };
781
782                struct_len -= field.encoded_len() + child.encoded_len();
783                if let Some(remaining_len) = max_len.checked_sub(struct_len) {
784                    // We're under budget after removing this field! See if we can
785                    // shrink it to fit, but exit the loop regardless.
786                    let Some(field_type) = field.data_type.as_mut() else {
787                        break;
788                    };
789                    maybe_trim_proto(field_type, &mut child, remaining_len);
790                    if field.encoded_len() + child.encoded_len() <= remaining_len {
791                        fields.push(field);
792                        body.children.push(child);
793                    }
794                    break;
795                }
796            }
797        }
798        _ => {}
799    };
800}
801
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// `to_proto_lower` should drop trailing struct columns until the encoded bound fits in
    /// the requested budget, and give up entirely when nothing can fit.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets too small to hold any encoding at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The 250-byte string in column "b" cannot fit in 100 bytes, so it and everything
        // after it is dropped.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        );
    }

    /// A struct holding only a prefix of another struct's columns can be compared with it,
    /// and sorts first when the shared columns are equal.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        // Comparison works as normal over the shared columns... but when those columns are identical,
        // the shorter struct is always smaller.
        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Comparing structs whose columns at the same position have different types panics.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        // This test is descriptive, not prescriptive: we declare it is an error to compare
        // structs like this, but not what the result of comparing them is.
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // Despite the matching first column, this will panic with a type mismatch.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}