1use std::cmp::Ordering;
26use std::fmt::Display;
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
// The contents of this module are included verbatim from the `mz_persist_types.arrow.rs`
// file in the build's `OUT_DIR` (presumably generated from protobuf definitions by the build
// script -- confirm), so missing-docs warnings are silenced for it rather than fixed by hand.
#[allow(missing_docs)]
mod proto {
    include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47 match data_type {
48 DataType::Struct(fields) => fields,
49 DataType::List(field) => std::slice::from_ref(field),
50 DataType::Map(field, _) => std::slice::from_ref(field),
51 DataType::Null
52 | DataType::Boolean
53 | DataType::Int8
54 | DataType::Int16
55 | DataType::Int32
56 | DataType::Int64
57 | DataType::UInt8
58 | DataType::UInt16
59 | DataType::UInt32
60 | DataType::UInt64
61 | DataType::Float16
62 | DataType::Float32
63 | DataType::Float64
64 | DataType::Timestamp(_, _)
65 | DataType::Date32
66 | DataType::Date64
67 | DataType::Time32(_)
68 | DataType::Time64(_)
69 | DataType::Duration(_)
70 | DataType::Interval(_)
71 | DataType::Binary
72 | DataType::FixedSizeBinary(_)
73 | DataType::LargeBinary
74 | DataType::BinaryView
75 | DataType::Utf8
76 | DataType::LargeUtf8
77 | DataType::Utf8View
78 | DataType::Decimal128(_, _)
79 | DataType::Decimal256(_, _) => &[],
80 DataType::ListView(_)
81 | DataType::FixedSizeList(_, _)
82 | DataType::LargeList(_)
83 | DataType::LargeListView(_)
84 | DataType::Union(_, _)
85 | DataType::Dictionary(_, _)
86 | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
87 }
88}
89
90fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
93 let data_type = match expected_type {
94 Some(expected) => {
95 soft_assert_eq_no_log!(
98 expected,
99 data.data_type(),
100 "actual type should match expected type"
101 );
102 None
103 }
104 None => Some(data.data_type().into_proto()),
105 };
106
107 ProtoArrayData {
108 data_type,
109 length: u64::cast_from(data.len()),
110 offset: u64::cast_from(data.offset()),
111 buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
112 children: data
113 .child_data()
114 .iter()
115 .zip_eq(fields_for_type(expected_type.unwrap_or(data.data_type())))
116 .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
117 .collect(),
118 nulls: data.nulls().map(|n| n.inner().into_proto()),
119 }
120}
121
122fn from_proto_with_type(
125 proto: ProtoArrayData,
126 expected_type: Option<&DataType>,
127) -> Result<ArrayData, TryFromProtoError> {
128 let ProtoArrayData {
129 data_type,
130 length,
131 offset,
132 buffers,
133 children,
134 nulls,
135 } = proto;
136 let data_type: Option<DataType> = data_type.into_rust()?;
137 let data_type = match (data_type, expected_type) {
138 (Some(data_type), None) => data_type,
139 (Some(data_type), Some(expected_type)) => {
140 soft_assert_eq_no_log!(
143 data_type,
144 *expected_type,
145 "expected type should match actual type"
146 );
147 data_type
148 }
149 (None, Some(expected_type)) => expected_type.clone(),
150 (None, None) => {
151 return Err(TryFromProtoError::MissingField(
152 "ProtoArrayData::data_type".to_string(),
153 ));
154 }
155 };
156 let nulls = nulls
157 .map(|n| n.into_rust())
158 .transpose()?
159 .map(NullBuffer::new);
160
161 let mut builder = ArrayDataBuilder::new(data_type.clone())
162 .len(usize::cast_from(length))
163 .offset(usize::cast_from(offset))
164 .nulls(nulls);
165
166 for b in buffers.into_iter().map(|b| b.into_rust()) {
167 builder = builder.add_buffer(b?);
168 }
169 for c in children
170 .into_iter()
171 .zip_eq(fields_for_type(&data_type))
172 .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
173 {
174 builder = builder.add_child_data(c?);
175 }
176
177 builder
179 .build_aligned()
180 .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
181}
182
183impl RustType<ProtoArrayData> for arrow::array::ArrayData {
184 fn into_proto(&self) -> ProtoArrayData {
185 into_proto_with_type(self, None)
186 }
187
188 fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
189 from_proto_with_type(proto, None)
190 }
191}
192
193impl RustType<proto::DataType> for arrow::datatypes::DataType {
194 fn into_proto(&self) -> proto::DataType {
195 let kind = match self {
196 DataType::Null => proto::data_type::Kind::Null(()),
197 DataType::Boolean => proto::data_type::Kind::Boolean(()),
198 DataType::UInt8 => proto::data_type::Kind::Uint8(()),
199 DataType::UInt16 => proto::data_type::Kind::Uint16(()),
200 DataType::UInt32 => proto::data_type::Kind::Uint32(()),
201 DataType::UInt64 => proto::data_type::Kind::Uint64(()),
202 DataType::Int8 => proto::data_type::Kind::Int8(()),
203 DataType::Int16 => proto::data_type::Kind::Int16(()),
204 DataType::Int32 => proto::data_type::Kind::Int32(()),
205 DataType::Int64 => proto::data_type::Kind::Int64(()),
206 DataType::Float32 => proto::data_type::Kind::Float32(()),
207 DataType::Float64 => proto::data_type::Kind::Float64(()),
208 DataType::Utf8 => proto::data_type::Kind::String(()),
209 DataType::Binary => proto::data_type::Kind::Binary(()),
210 DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
211 DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
212 DataType::Map(inner, sorted) => {
213 let map = proto::data_type::Map {
214 value: Some(Box::new(inner.into_proto())),
215 sorted: *sorted,
216 };
217 proto::data_type::Kind::Map(Box::new(map))
218 }
219 DataType::Struct(children) => {
220 let children = children.into_iter().map(|f| f.into_proto()).collect();
221 proto::data_type::Kind::Struct(proto::data_type::Struct { children })
222 }
223 other => unimplemented!("unsupported data type {other:?}"),
224 };
225
226 proto::DataType { kind: Some(kind) }
227 }
228
229 fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
230 let data_type = proto
231 .kind
232 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
233 let data_type = match data_type {
234 proto::data_type::Kind::Null(()) => DataType::Null,
235 proto::data_type::Kind::Boolean(()) => DataType::Boolean,
236 proto::data_type::Kind::Uint8(()) => DataType::UInt8,
237 proto::data_type::Kind::Uint16(()) => DataType::UInt16,
238 proto::data_type::Kind::Uint32(()) => DataType::UInt32,
239 proto::data_type::Kind::Uint64(()) => DataType::UInt64,
240 proto::data_type::Kind::Int8(()) => DataType::Int8,
241 proto::data_type::Kind::Int16(()) => DataType::Int16,
242 proto::data_type::Kind::Int32(()) => DataType::Int32,
243 proto::data_type::Kind::Int64(()) => DataType::Int64,
244 proto::data_type::Kind::Float32(()) => DataType::Float32,
245 proto::data_type::Kind::Float64(()) => DataType::Float64,
246 proto::data_type::Kind::String(()) => DataType::Utf8,
247 proto::data_type::Kind::Binary(()) => DataType::Binary,
248 proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
249 proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
250 proto::data_type::Kind::Map(inner) => {
251 let value = inner
252 .value
253 .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
254 DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
255 }
256 proto::data_type::Kind::Struct(inner) => {
257 let children: Vec<Field> = inner
258 .children
259 .into_iter()
260 .map(|c| c.into_rust())
261 .collect::<Result<_, _>>()?;
262 DataType::Struct(Fields::from(children))
263 }
264 };
265
266 Ok(data_type)
267 }
268}
269
270impl RustType<proto::Field> for arrow::datatypes::Field {
271 fn into_proto(&self) -> proto::Field {
272 proto::Field {
273 name: self.name().clone(),
274 nullable: self.is_nullable(),
275 data_type: Some(Box::new(self.data_type().into_proto())),
276 }
277 }
278
279 fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
280 let proto::Field {
281 name,
282 nullable,
283 data_type,
284 } = proto;
285 let data_type =
286 data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
287 let data_type = (*data_type).into_rust()?;
288
289 Ok(Field::new(name, data_type, nullable))
290 }
291}
292
293impl RustType<proto::Buffer> for arrow::buffer::Buffer {
294 fn into_proto(&self) -> proto::Buffer {
295 proto::Buffer {
297 data: bytes::Bytes::copy_from_slice(self.as_slice()),
298 }
299 }
300
301 fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
302 Ok(arrow::buffer::Buffer::from_bytes(proto.data.into()))
303 }
304}
305
306impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
307 fn into_proto(&self) -> proto::BooleanBuffer {
308 proto::BooleanBuffer {
309 buffer: Some(self.sliced().into_proto()),
310 length: u64::cast_from(self.len()),
311 }
312 }
313
314 fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
315 let proto::BooleanBuffer { buffer, length } = proto;
316 let buffer = buffer.into_rust_if_some("buffer")?;
317 Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
318 }
319}
320
/// A downcast view of an arrow array, with one variant per supported data type.
///
/// Built by `ArrayOrd::new`. Individual elements are accessed through `ArrayOrd::at`, which
/// returns an `ArrayIdx` that can be compared, displayed, and measured.
#[derive(Clone, Debug)]
pub enum ArrayOrd {
    /// An array of type `Null`.
    Null(NullArray),
    /// An array of type `Boolean`.
    Bool(BooleanArray),
    /// An array of type `Int8`.
    Int8(Int8Array),
    /// An array of type `Int16`.
    Int16(Int16Array),
    /// An array of type `Int32`.
    Int32(Int32Array),
    /// An array of type `Int64`.
    Int64(Int64Array),
    /// An array of type `UInt8`.
    UInt8(UInt8Array),
    /// An array of type `UInt16`.
    UInt16(UInt16Array),
    /// An array of type `UInt32`.
    UInt32(UInt32Array),
    /// An array of type `UInt64`.
    UInt64(UInt64Array),
    /// An array of type `Float32`.
    Float32(Float32Array),
    /// An array of type `Float64`.
    Float64(Float64Array),
    /// An array of type `Utf8`.
    String(StringArray),
    /// An array of type `Binary`.
    Binary(BinaryArray),
    /// An array of type `FixedSizeBinary`.
    FixedSizeBinary(FixedSizeBinaryArray),
    /// An array of type `List`: the list's null buffer (if any), its offsets, and the
    /// wrapped array of list values.
    List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
    /// An array of type `Struct`: the struct's null buffer (if any) and one wrapped array
    /// per column.
    Struct(Option<NullBuffer>, Vec<ArrayOrd>),
}
359
360impl ArrayOrd {
361 pub fn new(array: &dyn Array) -> Self {
363 match array.data_type() {
364 DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
365 DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
366 DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
367 DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
368 DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
369 DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
370 DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
371 DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
372 DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
373 DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
374 DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
375 DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
376 DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
377 DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
378 DataType::FixedSizeBinary(_) => {
379 ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
380 }
381 DataType::List(_) => {
382 let list_array = array.as_list();
383 ArrayOrd::List(
384 list_array.nulls().cloned(),
385 list_array.offsets().clone(),
386 Box::new(ArrayOrd::new(list_array.values())),
387 )
388 }
389 DataType::Struct(_) => {
390 let struct_array = array.as_struct();
391 let nulls = array.nulls().cloned();
392 let columns: Vec<_> = struct_array
393 .columns()
394 .iter()
395 .map(|a| ArrayOrd::new(a))
396 .collect();
397 ArrayOrd::Struct(nulls, columns)
398 }
399 data_type => unimplemented!("array type {data_type:?} not yet supported"),
400 }
401 }
402
403 pub fn goodbytes(&self) -> usize {
406 match self {
407 ArrayOrd::Null(_) => 0,
408 ArrayOrd::Bool(b) => b.len(),
411 ArrayOrd::Int8(a) => a.values().inner().len(),
412 ArrayOrd::Int16(a) => a.values().inner().len(),
413 ArrayOrd::Int32(a) => a.values().inner().len(),
414 ArrayOrd::Int64(a) => a.values().inner().len(),
415 ArrayOrd::UInt8(a) => a.values().inner().len(),
416 ArrayOrd::UInt16(a) => a.values().inner().len(),
417 ArrayOrd::UInt32(a) => a.values().inner().len(),
418 ArrayOrd::UInt64(a) => a.values().inner().len(),
419 ArrayOrd::Float32(a) => a.values().inner().len(),
420 ArrayOrd::Float64(a) => a.values().inner().len(),
421 ArrayOrd::String(a) => a.values().len(),
422 ArrayOrd::Binary(a) => a.values().len(),
423 ArrayOrd::FixedSizeBinary(a) => a.values().len(),
424 ArrayOrd::List(_, _, nested) => nested.goodbytes(),
425 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
426 }
427 }
428
429 pub fn at(&self, idx: usize) -> ArrayIdx {
431 ArrayIdx { idx, array: self }
432 }
433}
434
/// A reference to a single element of an [`ArrayOrd`].
///
/// This is the type that implements ordering, equality, and display for individual array
/// elements; it is cheap to copy since it only holds an index and a borrow.
#[derive(Clone, Copy, Debug)]
pub struct ArrayIdx<'a> {
    /// The position of the element within `array`.
    pub idx: usize,
    /// The array that the element belongs to.
    pub array: &'a ArrayOrd,
}
449
450impl Display for ArrayIdx<'_> {
451 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452 match self.array {
453 ArrayOrd::Null(_) => write!(f, "null"),
454 ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
455 ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
456 ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
457 ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
458 ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
459 ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
460 ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
461 ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
462 ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
463 ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
464 ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
465 ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
466 ArrayOrd::Binary(a) => {
467 for byte in a.value(self.idx) {
468 write!(f, "{:02x}", byte)?;
469 }
470 Ok(())
471 }
472 ArrayOrd::FixedSizeBinary(a) => {
473 for byte in a.value(self.idx) {
474 write!(f, "{:02x}", byte)?;
475 }
476 Ok(())
477 }
478 ArrayOrd::List(_, offsets, nested) => {
479 write!(
480 f,
481 "[{}]",
482 mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
483 )
484 }
485 ArrayOrd::Struct(_, nested) => write!(
486 f,
487 "{{{}}}",
488 mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
489 ),
490 }
491 }
492}
493
494#[inline]
495fn list_range<'a>(
496 offsets: &OffsetBuffer<i32>,
497 values: &'a ArrayOrd,
498 idx: usize,
499) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
500 let offsets = offsets.inner();
501 let from = offsets[idx].as_usize();
502 let to = offsets[idx + 1].as_usize();
503 (from..to).map(|i| values.at(i))
504}
505
506impl<'a> ArrayIdx<'a> {
507 pub fn goodbytes(&self) -> usize {
510 match self.array {
511 ArrayOrd::Null(_) => 0,
512 ArrayOrd::Bool(_) => size_of::<bool>(),
513 ArrayOrd::Int8(_) => size_of::<i8>(),
514 ArrayOrd::Int16(_) => size_of::<i16>(),
515 ArrayOrd::Int32(_) => size_of::<i32>(),
516 ArrayOrd::Int64(_) => size_of::<i64>(),
517 ArrayOrd::UInt8(_) => size_of::<u8>(),
518 ArrayOrd::UInt16(_) => size_of::<u16>(),
519 ArrayOrd::UInt32(_) => size_of::<u32>(),
520 ArrayOrd::UInt64(_) => size_of::<u64>(),
521 ArrayOrd::Float32(_) => size_of::<f32>(),
522 ArrayOrd::Float64(_) => size_of::<f64>(),
523 ArrayOrd::String(a) => a.value(self.idx).len(),
524 ArrayOrd::Binary(a) => a.value(self.idx).len(),
525 ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
526 ArrayOrd::List(_, offsets, nested) => {
527 list_range(offsets, nested, self.idx)
529 .map(|a| a.goodbytes())
530 .sum()
531 }
532 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
533 }
534 }
535}
536
537impl<'a> Ord for ArrayIdx<'a> {
538 fn cmp(&self, other: &Self) -> Ordering {
539 #[inline]
540 fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
541 buffer.as_ref().map_or(false, |b| b.is_null(idx))
542 }
543 #[inline]
544 fn cmp<A: ArrayAccessor>(
545 left: A,
546 left_idx: usize,
547 right: A,
548 right_idx: usize,
549 cmp: fn(&A::Item, &A::Item) -> Ordering,
550 ) -> Ordering {
551 match (left.is_null(left_idx), right.is_null(right_idx)) {
553 (false, true) => Ordering::Less,
554 (true, true) => Ordering::Equal,
555 (true, false) => Ordering::Greater,
556 (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
557 }
558 }
559 match (&self.array, &other.array) {
560 (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
561 debug_assert!(
562 self.idx < s.len() && other.idx < o.len(),
563 "null array indices in bounds"
564 );
565 Ordering::Equal
566 }
567 (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
569 (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
570 (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
571 (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
572 (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
573 (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
574 (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
575 (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
576 (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
577 (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
578 cmp(s, self.idx, o, other.idx, f32::total_cmp)
579 }
580 (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
581 cmp(s, self.idx, o, other.idx, f64::total_cmp)
582 }
583 (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
584 (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
585 (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
586 cmp(s, self.idx, o, other.idx, Ord::cmp)
587 }
588 (
591 ArrayOrd::List(s_nulls, s_offset, s_values),
592 ArrayOrd::List(o_nulls, o_offset, o_values),
593 ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
594 (false, true) => Ordering::Less,
595 (true, true) => Ordering::Equal,
596 (true, false) => Ordering::Greater,
597 (false, false) => list_range(s_offset, s_values, self.idx)
598 .cmp(list_range(o_offset, o_values, other.idx)),
599 },
600 (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
603 match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
604 (false, true) => Ordering::Less,
605 (true, true) => Ordering::Equal,
606 (true, false) => Ordering::Greater,
607 (false, false) => {
608 let s = s_cols.iter().map(|array| array.at(self.idx));
609 let o = o_cols.iter().map(|array| array.at(other.idx));
610 s.cmp(o)
611 }
612 }
613 }
614 (_, _) => panic!("array types did not match"),
615 }
616 }
617}
618
619impl<'a> PartialOrd for ArrayIdx<'a> {
620 fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
621 Some(self.cmp(other))
622 }
623}
624
625impl<'a> PartialEq for ArrayIdx<'a> {
626 fn eq(&self, other: &ArrayIdx) -> bool {
627 self.cmp(other) == Ordering::Equal
628 }
629}
630
631impl<'a> Eq for ArrayIdx<'a> {}
632
/// An array together with the index of one of its elements, owning the underlying data.
///
/// Keeps both the original array (used when encoding the element) and its downcast
/// [`ArrayOrd`] form (used when comparing the element).
#[derive(Debug, Clone)]
pub struct ArrayBound {
    // The original array that was passed in.
    raw: ArrayRef,
    // The downcast form of `raw`, built once at construction.
    ord: ArrayOrd,
    // The position of the element of interest within the array.
    index: usize,
}
640
641impl PartialEq for ArrayBound {
642 fn eq(&self, other: &Self) -> bool {
643 self.get().eq(&other.get())
644 }
645}
646
647impl Eq for ArrayBound {}
648
649impl ArrayBound {
650 pub fn new(array: ArrayRef, index: usize) -> Self {
652 Self {
653 ord: ArrayOrd::new(array.as_ref()),
654 raw: array,
655 index,
656 }
657 }
658
659 pub fn get(&self) -> ArrayIdx {
661 self.ord.at(self.index)
662 }
663
664 pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
668 let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
671 let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
672 let array_data = taken.into_data();
673
674 let mut proto = array_data.into_proto();
675 let original_len = proto.encoded_len();
676 if original_len <= max_len {
677 return Some(proto);
678 }
679
680 let mut data_type = proto.data_type.take()?;
681 maybe_trim_proto(&mut data_type, &mut proto, max_len);
682 proto.data_type = Some(data_type);
683
684 if cfg!(debug_assertions) {
685 let array: ArrayData = proto
686 .clone()
687 .into_rust()
688 .expect("trimmed array data can still be decoded");
689 assert_eq!(array.len(), 1);
690 let new_bound = Self::new(make_array(array), 0);
691 assert!(
692 new_bound.get() <= self.get(),
693 "trimmed bound should be comparable to and no larger than the original data"
694 )
695 }
696
697 if proto.encoded_len() <= max_len {
698 Some(proto)
699 } else {
700 None
701 }
702 }
703}
704
705fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
713 assert!(body.data_type.is_none(), "expected separate data type");
714 let encoded_len = data_type.encoded_len() + body.encoded_len();
716 match &mut data_type.kind {
717 Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
718 let mut struct_len = encoded_len;
720 while struct_len > max_len {
721 let Some(mut child) = body.children.pop() else {
722 break;
723 };
724 let Some(mut field) = fields.pop() else { break };
725
726 struct_len -= field.encoded_len() + child.encoded_len();
727 if let Some(remaining_len) = max_len.checked_sub(struct_len) {
728 let Some(field_type) = field.data_type.as_mut() else {
731 break;
732 };
733 maybe_trim_proto(field_type, &mut child, remaining_len);
734 if field.encoded_len() + child.encoded_len() <= remaining_len {
735 fields.push(field);
736 body.children.push(child);
737 }
738 break;
739 }
740 }
741 }
742 _ => {}
743 };
744}
745
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// Encoding a struct bound under various size budgets drops trailing columns until the
    /// encoding fits, or yields nothing when even the smallest encoding is too large.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets too small to hold anything at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The large string column does not fit, so it and everything after it is dropped.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        // With a generous budget nothing is trimmed.
        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        )
    }

    /// Structs compare column by column, and a struct whose columns are a prefix of
    /// another's sorts first when the shared columns are equal.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Comparing structs whose columns have different types panics once the comparison
    /// reaches the mismatched column.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // The first columns are equal, so the comparison moves on to the second column,
        // where the types differ.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}