mz_persist_types/
arrow.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
//! A [protobuf] representation of [Apache Arrow] arrays.
//!
//! # Motivation
//!
//! Persist can store a small amount of data inline at the consensus layer.
//! Because we are space constrained, we take particular care to store only the
//! data that is necessary. Other Arrow serialization formats, e.g. [Parquet]
//! or [Arrow IPC], include data that we don't need and would be wasteful to
//! store.
//!
//! [protobuf]: https://protobuf.dev/
//! [Apache Arrow]: https://arrow.apache.org/
//! [Parquet]: https://parquet.apache.org/docs/
//! [Arrow IPC]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
24
25use std::cmp::Ordering;
26use std::fmt::Display;
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
// The module body is pulled in verbatim from a file in the build's `OUT_DIR`
// (presumably emitted by the protobuf build step -- confirm against the build script),
// so its items carry no doc comments of their own; hence the `missing_docs` allowance.
#[allow(missing_docs)]
mod proto {
    include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45/// Extract the list of fields for our recursive datatypes.
46pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47    match data_type {
48        DataType::Struct(fields) => fields,
49        DataType::List(field) => std::slice::from_ref(field),
50        DataType::Map(field, _) => std::slice::from_ref(field),
51        DataType::Null
52        | DataType::Boolean
53        | DataType::Int8
54        | DataType::Int16
55        | DataType::Int32
56        | DataType::Int64
57        | DataType::UInt8
58        | DataType::UInt16
59        | DataType::UInt32
60        | DataType::UInt64
61        | DataType::Float16
62        | DataType::Float32
63        | DataType::Float64
64        | DataType::Timestamp(_, _)
65        | DataType::Date32
66        | DataType::Date64
67        | DataType::Time32(_)
68        | DataType::Time64(_)
69        | DataType::Duration(_)
70        | DataType::Interval(_)
71        | DataType::Binary
72        | DataType::FixedSizeBinary(_)
73        | DataType::LargeBinary
74        | DataType::BinaryView
75        | DataType::Utf8
76        | DataType::LargeUtf8
77        | DataType::Utf8View
78        | DataType::Decimal128(_, _)
79        | DataType::Decimal256(_, _) => &[],
80        DataType::ListView(_)
81        | DataType::FixedSizeList(_, _)
82        | DataType::LargeList(_)
83        | DataType::LargeListView(_)
84        | DataType::Union(_, _)
85        | DataType::Dictionary(_, _)
86        | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
87    }
88}
89
90/// Encode the array into proto. If an expected data type is passed, that implies it is
91/// encoded at some higher level, and we omit it from the data.
92fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
93    let data_type = match expected_type {
94        Some(expected) => {
95            // Equality is recursive, and this function is itself called recursively,
96            // skip the call in production to avoid a quadratic overhead.
97            soft_assert_eq_no_log!(
98                expected,
99                data.data_type(),
100                "actual type should match expected type"
101            );
102            None
103        }
104        None => Some(data.data_type().into_proto()),
105    };
106
107    ProtoArrayData {
108        data_type,
109        length: u64::cast_from(data.len()),
110        offset: u64::cast_from(data.offset()),
111        buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
112        children: data
113            .child_data()
114            .iter()
115            .zip_eq(fields_for_type(
116                expected_type.unwrap_or_else(|| data.data_type()),
117            ))
118            .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
119            .collect(),
120        nulls: data.nulls().map(|n| n.inner().into_proto()),
121    }
122}
123
124/// Decode the array data.
125/// If the data type is omitted from the proto, we decode it as the expected type.
126fn from_proto_with_type(
127    proto: ProtoArrayData,
128    expected_type: Option<&DataType>,
129) -> Result<ArrayData, TryFromProtoError> {
130    let ProtoArrayData {
131        data_type,
132        length,
133        offset,
134        buffers,
135        children,
136        nulls,
137    } = proto;
138    let data_type: Option<DataType> = data_type.into_rust()?;
139    let data_type = match (data_type, expected_type) {
140        (Some(data_type), None) => data_type,
141        (Some(data_type), Some(expected_type)) => {
142            // Equality is recursive, and this function is itself called recursively,
143            // skip the call in production to avoid a quadratic overhead.
144            soft_assert_eq_no_log!(
145                data_type,
146                *expected_type,
147                "expected type should match actual type"
148            );
149            data_type
150        }
151        (None, Some(expected_type)) => expected_type.clone(),
152        (None, None) => {
153            return Err(TryFromProtoError::MissingField(
154                "ProtoArrayData::data_type".to_string(),
155            ));
156        }
157    };
158    let nulls = nulls
159        .map(|n| n.into_rust())
160        .transpose()?
161        .map(NullBuffer::new);
162
163    let mut builder = ArrayDataBuilder::new(data_type.clone())
164        .len(usize::cast_from(length))
165        .offset(usize::cast_from(offset))
166        .nulls(nulls);
167
168    for b in buffers.into_iter().map(|b| b.into_rust()) {
169        builder = builder.add_buffer(b?);
170    }
171    for c in children
172        .into_iter()
173        .zip_eq(fields_for_type(&data_type))
174        .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
175    {
176        builder = builder.add_child_data(c?);
177    }
178
179    // Construct the builder which validates all inputs and aligns data.
180    builder
181        .build_aligned()
182        .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
183}
184
185impl RustType<ProtoArrayData> for arrow::array::ArrayData {
186    fn into_proto(&self) -> ProtoArrayData {
187        into_proto_with_type(self, None)
188    }
189
190    fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
191        from_proto_with_type(proto, None)
192    }
193}
194
195impl RustType<proto::DataType> for arrow::datatypes::DataType {
196    fn into_proto(&self) -> proto::DataType {
197        let kind = match self {
198            DataType::Null => proto::data_type::Kind::Null(()),
199            DataType::Boolean => proto::data_type::Kind::Boolean(()),
200            DataType::UInt8 => proto::data_type::Kind::Uint8(()),
201            DataType::UInt16 => proto::data_type::Kind::Uint16(()),
202            DataType::UInt32 => proto::data_type::Kind::Uint32(()),
203            DataType::UInt64 => proto::data_type::Kind::Uint64(()),
204            DataType::Int8 => proto::data_type::Kind::Int8(()),
205            DataType::Int16 => proto::data_type::Kind::Int16(()),
206            DataType::Int32 => proto::data_type::Kind::Int32(()),
207            DataType::Int64 => proto::data_type::Kind::Int64(()),
208            DataType::Float32 => proto::data_type::Kind::Float32(()),
209            DataType::Float64 => proto::data_type::Kind::Float64(()),
210            DataType::Utf8 => proto::data_type::Kind::String(()),
211            DataType::Binary => proto::data_type::Kind::Binary(()),
212            DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
213            DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
214            DataType::Map(inner, sorted) => {
215                let map = proto::data_type::Map {
216                    value: Some(Box::new(inner.into_proto())),
217                    sorted: *sorted,
218                };
219                proto::data_type::Kind::Map(Box::new(map))
220            }
221            DataType::Struct(children) => {
222                let children = children.into_iter().map(|f| f.into_proto()).collect();
223                proto::data_type::Kind::Struct(proto::data_type::Struct { children })
224            }
225            other => unimplemented!("unsupported data type {other:?}"),
226        };
227
228        proto::DataType { kind: Some(kind) }
229    }
230
231    fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
232        let data_type = proto
233            .kind
234            .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
235        let data_type = match data_type {
236            proto::data_type::Kind::Null(()) => DataType::Null,
237            proto::data_type::Kind::Boolean(()) => DataType::Boolean,
238            proto::data_type::Kind::Uint8(()) => DataType::UInt8,
239            proto::data_type::Kind::Uint16(()) => DataType::UInt16,
240            proto::data_type::Kind::Uint32(()) => DataType::UInt32,
241            proto::data_type::Kind::Uint64(()) => DataType::UInt64,
242            proto::data_type::Kind::Int8(()) => DataType::Int8,
243            proto::data_type::Kind::Int16(()) => DataType::Int16,
244            proto::data_type::Kind::Int32(()) => DataType::Int32,
245            proto::data_type::Kind::Int64(()) => DataType::Int64,
246            proto::data_type::Kind::Float32(()) => DataType::Float32,
247            proto::data_type::Kind::Float64(()) => DataType::Float64,
248            proto::data_type::Kind::String(()) => DataType::Utf8,
249            proto::data_type::Kind::Binary(()) => DataType::Binary,
250            proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
251            proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
252            proto::data_type::Kind::Map(inner) => {
253                let value = inner
254                    .value
255                    .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
256                DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
257            }
258            proto::data_type::Kind::Struct(inner) => {
259                let children: Vec<Field> = inner
260                    .children
261                    .into_iter()
262                    .map(|c| c.into_rust())
263                    .collect::<Result<_, _>>()?;
264                DataType::Struct(Fields::from(children))
265            }
266        };
267
268        Ok(data_type)
269    }
270}
271
272impl RustType<proto::Field> for arrow::datatypes::Field {
273    fn into_proto(&self) -> proto::Field {
274        proto::Field {
275            name: self.name().clone(),
276            nullable: self.is_nullable(),
277            data_type: Some(Box::new(self.data_type().into_proto())),
278        }
279    }
280
281    fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
282        let proto::Field {
283            name,
284            nullable,
285            data_type,
286        } = proto;
287        let data_type =
288            data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
289        let data_type = (*data_type).into_rust()?;
290
291        Ok(Field::new(name, data_type, nullable))
292    }
293}
294
295impl RustType<proto::Buffer> for arrow::buffer::Buffer {
296    fn into_proto(&self) -> proto::Buffer {
297        // TODO(parkmycar): There is probably something better we can do here.
298        proto::Buffer {
299            data: bytes::Bytes::copy_from_slice(self.as_slice()),
300        }
301    }
302
303    fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
304        Ok(arrow::buffer::Buffer::from_bytes(proto.data.into()))
305    }
306}
307
308impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
309    fn into_proto(&self) -> proto::BooleanBuffer {
310        proto::BooleanBuffer {
311            buffer: Some(self.sliced().into_proto()),
312            length: u64::cast_from(self.len()),
313        }
314    }
315
316    fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
317        let proto::BooleanBuffer { buffer, length } = proto;
318        let buffer = buffer.into_rust_if_some("buffer")?;
319        Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
320    }
321}
322
/// Wraps a single arrow array, downcasted to a specific type.
///
/// Leaf variants hold the concretely-typed arrow array. The nested variants (`List` and
/// `Struct`) are broken up into their null buffer plus recursively wrapped children.
#[derive(Clone, Debug)]
pub enum ArrayOrd {
    /// Wraps a `NullArray`.
    Null(NullArray),
    /// Wraps a `BooleanArray`.
    Bool(BooleanArray),
    /// Wraps an `Int8Array`.
    Int8(Int8Array),
    /// Wraps an `Int16Array`.
    Int16(Int16Array),
    /// Wraps an `Int32Array`.
    Int32(Int32Array),
    /// Wraps an `Int64Array`.
    Int64(Int64Array),
    /// Wraps a `UInt8Array`.
    UInt8(UInt8Array),
    /// Wraps a `UInt16Array`.
    UInt16(UInt16Array),
    /// Wraps a `UInt32Array`.
    UInt32(UInt32Array),
    /// Wraps a `UInt64Array`.
    UInt64(UInt64Array),
    /// Wraps a `Float32Array`.
    Float32(Float32Array),
    /// Wraps a `Float64Array`.
    Float64(Float64Array),
    /// Wraps a `StringArray`.
    String(StringArray),
    /// Wraps a `BinaryArray`.
    Binary(BinaryArray),
    /// Wraps a `FixedSizeBinaryArray`.
    FixedSizeBinary(FixedSizeBinaryArray),
    /// The parts of a `List` array: its optional null buffer, its offsets into the values
    /// array, and the wrapped values array.
    List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
    /// The parts of a `Struct` array: its optional null buffer and its wrapped columns,
    /// in field order.
    Struct(Option<NullBuffer>, Vec<ArrayOrd>),
}
361
362impl ArrayOrd {
363    /// Downcast the provided array to a specific type in our enum.
364    pub fn new(array: &dyn Array) -> Self {
365        match array.data_type() {
366            DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
367            DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
368            DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
369            DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
370            DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
371            DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
372            DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
373            DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
374            DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
375            DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
376            DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
377            DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
378            DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
379            DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
380            DataType::FixedSizeBinary(_) => {
381                ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
382            }
383            DataType::List(_) => {
384                let list_array = array.as_list();
385                ArrayOrd::List(
386                    list_array.nulls().cloned(),
387                    list_array.offsets().clone(),
388                    Box::new(ArrayOrd::new(list_array.values())),
389                )
390            }
391            DataType::Struct(_) => {
392                let struct_array = array.as_struct();
393                let nulls = array.nulls().cloned();
394                let columns: Vec<_> = struct_array
395                    .columns()
396                    .iter()
397                    .map(|a| ArrayOrd::new(a))
398                    .collect();
399                ArrayOrd::Struct(nulls, columns)
400            }
401            data_type => unimplemented!("array type {data_type:?} not yet supported"),
402        }
403    }
404
405    /// Returns the rough amount of space required for the data in this array in bytes.
406    /// (Not counting nulls, dictionary encoding, or other space optimizations.)
407    pub fn goodbytes(&self) -> usize {
408        match self {
409            ArrayOrd::Null(_) => 0,
410            // This is, strictly speaking, wrong - but consistent with `ArrayIdx::goodbytes`,
411            // which counts one byte per bool.
412            ArrayOrd::Bool(b) => b.len(),
413            ArrayOrd::Int8(a) => a.values().inner().len(),
414            ArrayOrd::Int16(a) => a.values().inner().len(),
415            ArrayOrd::Int32(a) => a.values().inner().len(),
416            ArrayOrd::Int64(a) => a.values().inner().len(),
417            ArrayOrd::UInt8(a) => a.values().inner().len(),
418            ArrayOrd::UInt16(a) => a.values().inner().len(),
419            ArrayOrd::UInt32(a) => a.values().inner().len(),
420            ArrayOrd::UInt64(a) => a.values().inner().len(),
421            ArrayOrd::Float32(a) => a.values().inner().len(),
422            ArrayOrd::Float64(a) => a.values().inner().len(),
423            ArrayOrd::String(a) => a.values().len(),
424            ArrayOrd::Binary(a) => a.values().len(),
425            ArrayOrd::FixedSizeBinary(a) => a.values().len(),
426            ArrayOrd::List(_, _, nested) => nested.goodbytes(),
427            ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
428        }
429    }
430
431    /// Return a struct representing the value at a particular index in this array.
432    pub fn at(&self, idx: usize) -> ArrayIdx {
433        ArrayIdx { idx, array: self }
434    }
435}
436
/// A struct representing a particular entry in a particular array. Most useful for its `Ord`
/// implementation, which can compare entire rows across similarly-typed arrays.
///
/// Null entries sort after all non-null entries.
///
/// It is an error to compare indices from arrays with different types (the comparison
/// panics), with one exception: it is valid to compare two `StructArray`s, one of which is
/// a prefix of the other... in which case we'll compare the values on that subset of the
/// fields, and the shorter of the two structs will compare less if they're otherwise equal.
#[derive(Clone, Copy, Debug)]
pub struct ArrayIdx<'a> {
    /// An index into a particular array.
    pub idx: usize,
    /// The particular array.
    pub array: &'a ArrayOrd,
}
451
452impl Display for ArrayIdx<'_> {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        match self.array {
455            ArrayOrd::Null(_) => write!(f, "null"),
456            ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
457            ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
458            ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
459            ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
460            ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
461            ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
462            ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
463            ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
464            ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
465            ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
466            ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
467            ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
468            ArrayOrd::Binary(a) => {
469                for byte in a.value(self.idx) {
470                    write!(f, "{:02x}", byte)?;
471                }
472                Ok(())
473            }
474            ArrayOrd::FixedSizeBinary(a) => {
475                for byte in a.value(self.idx) {
476                    write!(f, "{:02x}", byte)?;
477                }
478                Ok(())
479            }
480            ArrayOrd::List(_, offsets, nested) => {
481                write!(
482                    f,
483                    "[{}]",
484                    mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
485                )
486            }
487            ArrayOrd::Struct(_, nested) => write!(
488                f,
489                "{{{}}}",
490                mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
491            ),
492        }
493    }
494}
495
496#[inline]
497fn list_range<'a>(
498    offsets: &OffsetBuffer<i32>,
499    values: &'a ArrayOrd,
500    idx: usize,
501) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
502    let offsets = offsets.inner();
503    let from = offsets[idx].as_usize();
504    let to = offsets[idx + 1].as_usize();
505    (from..to).map(|i| values.at(i))
506}
507
508impl<'a> ArrayIdx<'a> {
509    /// Returns the rough amount of space required for this entry in bytes.
510    /// (Not counting nulls, dictionary encoding, or other space optimizations.)
511    pub fn goodbytes(&self) -> usize {
512        match self.array {
513            ArrayOrd::Null(_) => 0,
514            ArrayOrd::Bool(_) => size_of::<bool>(),
515            ArrayOrd::Int8(_) => size_of::<i8>(),
516            ArrayOrd::Int16(_) => size_of::<i16>(),
517            ArrayOrd::Int32(_) => size_of::<i32>(),
518            ArrayOrd::Int64(_) => size_of::<i64>(),
519            ArrayOrd::UInt8(_) => size_of::<u8>(),
520            ArrayOrd::UInt16(_) => size_of::<u16>(),
521            ArrayOrd::UInt32(_) => size_of::<u32>(),
522            ArrayOrd::UInt64(_) => size_of::<u64>(),
523            ArrayOrd::Float32(_) => size_of::<f32>(),
524            ArrayOrd::Float64(_) => size_of::<f64>(),
525            ArrayOrd::String(a) => a.value(self.idx).len(),
526            ArrayOrd::Binary(a) => a.value(self.idx).len(),
527            ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
528            ArrayOrd::List(_, offsets, nested) => {
529                // Range over the list, summing up the bytes for each entry.
530                list_range(offsets, nested, self.idx)
531                    .map(|a| a.goodbytes())
532                    .sum()
533            }
534            ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
535        }
536    }
537}
538
539impl<'a> Ord for ArrayIdx<'a> {
540    fn cmp(&self, other: &Self) -> Ordering {
541        #[inline]
542        fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
543            buffer.as_ref().map_or(false, |b| b.is_null(idx))
544        }
545        #[inline]
546        fn cmp<A: ArrayAccessor>(
547            left: A,
548            left_idx: usize,
549            right: A,
550            right_idx: usize,
551            cmp: fn(&A::Item, &A::Item) -> Ordering,
552        ) -> Ordering {
553            // NB: nulls sort last, conveniently matching psql / mz_repr
554            match (left.is_null(left_idx), right.is_null(right_idx)) {
555                (false, true) => Ordering::Less,
556                (true, true) => Ordering::Equal,
557                (true, false) => Ordering::Greater,
558                (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
559            }
560        }
561        match (&self.array, &other.array) {
562            (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
563                debug_assert!(
564                    self.idx < s.len() && other.idx < o.len(),
565                    "null array indices in bounds"
566                );
567                Ordering::Equal
568            }
569            // For arrays with "simple" value types, we fetch and compare the underlying values directly.
570            (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
571            (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
572            (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
573            (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
574            (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
575            (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
576            (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
577            (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
578            (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
579            (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
580                cmp(s, self.idx, o, other.idx, f32::total_cmp)
581            }
582            (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
583                cmp(s, self.idx, o, other.idx, f64::total_cmp)
584            }
585            (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
586            (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
587            (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
588                cmp(s, self.idx, o, other.idx, Ord::cmp)
589            }
590            // For lists, we generate an iterator for each side that ranges over the correct
591            // indices into the value buffer, then compare them lexicographically.
592            (
593                ArrayOrd::List(s_nulls, s_offset, s_values),
594                ArrayOrd::List(o_nulls, o_offset, o_values),
595            ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
596                (false, true) => Ordering::Less,
597                (true, true) => Ordering::Equal,
598                (true, false) => Ordering::Greater,
599                (false, false) => list_range(s_offset, s_values, self.idx)
600                    .cmp(list_range(o_offset, o_values, other.idx)),
601            },
602            // For structs, we iterate over the same index in each field for each input,
603            // comparing them lexicographically in order.
604            (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
605                match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
606                    (false, true) => Ordering::Less,
607                    (true, true) => Ordering::Equal,
608                    (true, false) => Ordering::Greater,
609                    (false, false) => {
610                        let s = s_cols.iter().map(|array| array.at(self.idx));
611                        let o = o_cols.iter().map(|array| array.at(other.idx));
612                        s.cmp(o)
613                    }
614                }
615            }
616            (_, _) => panic!("array types did not match"),
617        }
618    }
619}
620
621impl<'a> PartialOrd for ArrayIdx<'a> {
622    fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
623        Some(self.cmp(other))
624    }
625}
626
627impl<'a> PartialEq for ArrayIdx<'a> {
628    fn eq(&self, other: &ArrayIdx) -> bool {
629        self.cmp(other) == Ordering::Equal
630    }
631}
632
633impl<'a> Eq for ArrayIdx<'a> {}
634
/// A single entry of an array, identified by its index, for use as a lower bound.
#[derive(Debug, Clone)]
pub struct ArrayBound {
    // The array as originally provided.
    raw: ArrayRef,
    // The same array, wrapped as an `ArrayOrd` so that the bound can be compared.
    ord: ArrayOrd,
    // The position of the bound within the array.
    index: usize,
}
642
643impl PartialEq for ArrayBound {
644    fn eq(&self, other: &Self) -> bool {
645        self.get().eq(&other.get())
646    }
647}
648
649impl Eq for ArrayBound {}
650
651impl ArrayBound {
652    /// Create a new `ArrayBound` for this array, with the bound at the provided index.
653    pub fn new(array: ArrayRef, index: usize) -> Self {
654        Self {
655            ord: ArrayOrd::new(array.as_ref()),
656            raw: array,
657            index,
658        }
659    }
660
661    /// Get the value of the bound.
662    pub fn get(&self) -> ArrayIdx {
663        self.ord.at(self.index)
664    }
665
666    /// Convert to an array-data proto, respecting a maximum data size. The resulting proto will
667    /// decode to a single-row array, such that `ArrayBound::new(decoded, 0).get() <= self.get()`,
668    /// which makes it suitable as a lower bound.
669    pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
670        // Use `take` instead of slice to make sure we encode just the relevant row to proto,
671        // instead of some larger buffer with an offset.
672        let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
673        let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
674        let array_data = taken.into_data();
675
676        let mut proto = array_data.into_proto();
677        let original_len = proto.encoded_len();
678        if original_len <= max_len {
679            return Some(proto);
680        }
681
682        let mut data_type = proto.data_type.take()?;
683        maybe_trim_proto(&mut data_type, &mut proto, max_len);
684        proto.data_type = Some(data_type);
685
686        if cfg!(debug_assertions) {
687            let array: ArrayData = proto
688                .clone()
689                .into_rust()
690                .expect("trimmed array data can still be decoded");
691            assert_eq!(array.len(), 1);
692            let new_bound = Self::new(make_array(array), 0);
693            assert!(
694                new_bound.get() <= self.get(),
695                "trimmed bound should be comparable to and no larger than the original data"
696            )
697        }
698
699        if proto.encoded_len() <= max_len {
700            Some(proto)
701        } else {
702            None
703        }
704    }
705}
706
707/// Makes a best effort to shrink the proto while preserving the ordering.
708/// (The proto might not be smaller after this method is called, but it should always
709/// be a valid lower bound.)
710///
711/// Note that we pass in the data type and the array data separately, since we only keep
712/// type info at the top level. If a caller does have a top-level `ArrayData` instance,
713/// they should take that type and pass it in separately.
714fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
715    assert!(body.data_type.is_none(), "expected separate data type");
716    // TODO: consider adding cases for strings and byte arrays
717    let encoded_len = data_type.encoded_len() + body.encoded_len();
718    match &mut data_type.kind {
719        Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
720            // Pop off fields one by one, keeping an estimate of the encoded length.
721            let mut struct_len = encoded_len;
722            while struct_len > max_len {
723                let Some(mut child) = body.children.pop() else {
724                    break;
725                };
726                let Some(mut field) = fields.pop() else { break };
727
728                struct_len -= field.encoded_len() + child.encoded_len();
729                if let Some(remaining_len) = max_len.checked_sub(struct_len) {
730                    // We're under budget after removing this field! See if we can
731                    // shrink it to fit, but exit the loop regardless.
732                    let Some(field_type) = field.data_type.as_mut() else {
733                        break;
734                    };
735                    maybe_trim_proto(field_type, &mut child, remaining_len);
736                    if field.encoded_len() + child.encoded_len() <= remaining_len {
737                        fields.push(field);
738                        body.children.push(child);
739                    }
740                    break;
741                }
742            }
743        }
744        _ => {}
745    };
746}
747
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// Checks that `to_proto_lower` drops trailing struct columns to fit a byte budget.
    ///
    /// The input is a single-row struct with a small integer column `a`, a 250-byte
    /// string column `b`, and a null nested struct column `c`.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets this small cannot hold any encoding at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The large string column does not fit, so it and everything after it is cut.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        // With a generous budget nothing is trimmed.
        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        );
    }

    /// Checks the ordering between a struct and a struct holding only a prefix of its
    /// columns.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        // Comparison works as normal over the shared columns... but when those columns are identical,
        // the shorter struct is always smaller.
        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Documents that comparing structs whose shared columns differ in type panics.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        // This test is descriptive, not prescriptive: we declare it is an error to compare
        // structs like this, but not what the result of comparing them is.
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // Despite the matching first column, this will panic with a type mismatch.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}