1use std::cmp::Ordering;
26use std::fmt::Display;
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
// The `missing_docs` lint is silenced for everything pulled in below.
// NOTE(review): presumably because this is build-generated protobuf code — confirm.
#[allow(missing_docs)]
mod proto {
    // Textually include the `mz_persist_types.arrow.rs` file found in the
    // build's `OUT_DIR`.
    include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47 match data_type {
48 DataType::Struct(fields) => fields,
49 DataType::List(field) => std::slice::from_ref(field),
50 DataType::Map(field, _) => std::slice::from_ref(field),
51 DataType::Null
52 | DataType::Boolean
53 | DataType::Int8
54 | DataType::Int16
55 | DataType::Int32
56 | DataType::Int64
57 | DataType::UInt8
58 | DataType::UInt16
59 | DataType::UInt32
60 | DataType::UInt64
61 | DataType::Float16
62 | DataType::Float32
63 | DataType::Float64
64 | DataType::Timestamp(_, _)
65 | DataType::Date32
66 | DataType::Date64
67 | DataType::Time32(_)
68 | DataType::Time64(_)
69 | DataType::Duration(_)
70 | DataType::Interval(_)
71 | DataType::Binary
72 | DataType::FixedSizeBinary(_)
73 | DataType::LargeBinary
74 | DataType::BinaryView
75 | DataType::Utf8
76 | DataType::LargeUtf8
77 | DataType::Utf8View
78 | DataType::Decimal128(_, _)
79 | DataType::Decimal256(_, _) => &[],
80 DataType::ListView(_)
81 | DataType::FixedSizeList(_, _)
82 | DataType::LargeList(_)
83 | DataType::LargeListView(_)
84 | DataType::Union(_, _)
85 | DataType::Dictionary(_, _)
86 | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
87 }
88}
89
90fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
93 let data_type = match expected_type {
94 Some(expected) => {
95 soft_assert_eq_no_log!(
98 expected,
99 data.data_type(),
100 "actual type should match expected type"
101 );
102 None
103 }
104 None => Some(data.data_type().into_proto()),
105 };
106
107 ProtoArrayData {
108 data_type,
109 length: u64::cast_from(data.len()),
110 offset: u64::cast_from(data.offset()),
111 buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
112 children: data
113 .child_data()
114 .iter()
115 .zip_eq(fields_for_type(
116 expected_type.unwrap_or_else(|| data.data_type()),
117 ))
118 .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
119 .collect(),
120 nulls: data.nulls().map(|n| n.inner().into_proto()),
121 }
122}
123
124fn from_proto_with_type(
127 proto: ProtoArrayData,
128 expected_type: Option<&DataType>,
129) -> Result<ArrayData, TryFromProtoError> {
130 let ProtoArrayData {
131 data_type,
132 length,
133 offset,
134 buffers,
135 children,
136 nulls,
137 } = proto;
138 let data_type: Option<DataType> = data_type.into_rust()?;
139 let data_type = match (data_type, expected_type) {
140 (Some(data_type), None) => data_type,
141 (Some(data_type), Some(expected_type)) => {
142 soft_assert_eq_no_log!(
145 data_type,
146 *expected_type,
147 "expected type should match actual type"
148 );
149 data_type
150 }
151 (None, Some(expected_type)) => expected_type.clone(),
152 (None, None) => {
153 return Err(TryFromProtoError::MissingField(
154 "ProtoArrayData::data_type".to_string(),
155 ));
156 }
157 };
158 let nulls = nulls
159 .map(|n| n.into_rust())
160 .transpose()?
161 .map(NullBuffer::new);
162
163 let mut builder = ArrayDataBuilder::new(data_type.clone())
164 .len(usize::cast_from(length))
165 .offset(usize::cast_from(offset))
166 .nulls(nulls);
167
168 for b in buffers.into_iter().map(|b| b.into_rust()) {
169 builder = builder.add_buffer(b?);
170 }
171 for c in children
172 .into_iter()
173 .zip_eq(fields_for_type(&data_type))
174 .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
175 {
176 builder = builder.add_child_data(c?);
177 }
178
179 builder
181 .align_buffers(true)
182 .build()
183 .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
184}
185
186impl RustType<ProtoArrayData> for arrow::array::ArrayData {
187 fn into_proto(&self) -> ProtoArrayData {
188 into_proto_with_type(self, None)
189 }
190
191 fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
192 from_proto_with_type(proto, None)
193 }
194}
195
196impl RustType<proto::DataType> for arrow::datatypes::DataType {
197 fn into_proto(&self) -> proto::DataType {
198 let kind = match self {
199 DataType::Null => proto::data_type::Kind::Null(()),
200 DataType::Boolean => proto::data_type::Kind::Boolean(()),
201 DataType::UInt8 => proto::data_type::Kind::Uint8(()),
202 DataType::UInt16 => proto::data_type::Kind::Uint16(()),
203 DataType::UInt32 => proto::data_type::Kind::Uint32(()),
204 DataType::UInt64 => proto::data_type::Kind::Uint64(()),
205 DataType::Int8 => proto::data_type::Kind::Int8(()),
206 DataType::Int16 => proto::data_type::Kind::Int16(()),
207 DataType::Int32 => proto::data_type::Kind::Int32(()),
208 DataType::Int64 => proto::data_type::Kind::Int64(()),
209 DataType::Float32 => proto::data_type::Kind::Float32(()),
210 DataType::Float64 => proto::data_type::Kind::Float64(()),
211 DataType::Utf8 => proto::data_type::Kind::String(()),
212 DataType::Binary => proto::data_type::Kind::Binary(()),
213 DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
214 DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
215 DataType::Map(inner, sorted) => {
216 let map = proto::data_type::Map {
217 value: Some(Box::new(inner.into_proto())),
218 sorted: *sorted,
219 };
220 proto::data_type::Kind::Map(Box::new(map))
221 }
222 DataType::Struct(children) => {
223 let children = children.into_iter().map(|f| f.into_proto()).collect();
224 proto::data_type::Kind::Struct(proto::data_type::Struct { children })
225 }
226 other => unimplemented!("unsupported data type {other:?}"),
227 };
228
229 proto::DataType { kind: Some(kind) }
230 }
231
232 fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
233 let data_type = proto
234 .kind
235 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
236 let data_type = match data_type {
237 proto::data_type::Kind::Null(()) => DataType::Null,
238 proto::data_type::Kind::Boolean(()) => DataType::Boolean,
239 proto::data_type::Kind::Uint8(()) => DataType::UInt8,
240 proto::data_type::Kind::Uint16(()) => DataType::UInt16,
241 proto::data_type::Kind::Uint32(()) => DataType::UInt32,
242 proto::data_type::Kind::Uint64(()) => DataType::UInt64,
243 proto::data_type::Kind::Int8(()) => DataType::Int8,
244 proto::data_type::Kind::Int16(()) => DataType::Int16,
245 proto::data_type::Kind::Int32(()) => DataType::Int32,
246 proto::data_type::Kind::Int64(()) => DataType::Int64,
247 proto::data_type::Kind::Float32(()) => DataType::Float32,
248 proto::data_type::Kind::Float64(()) => DataType::Float64,
249 proto::data_type::Kind::String(()) => DataType::Utf8,
250 proto::data_type::Kind::Binary(()) => DataType::Binary,
251 proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
252 proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
253 proto::data_type::Kind::Map(inner) => {
254 let value = inner
255 .value
256 .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
257 DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
258 }
259 proto::data_type::Kind::Struct(inner) => {
260 let children: Vec<Field> = inner
261 .children
262 .into_iter()
263 .map(|c| c.into_rust())
264 .collect::<Result<_, _>>()?;
265 DataType::Struct(Fields::from(children))
266 }
267 };
268
269 Ok(data_type)
270 }
271}
272
273impl RustType<proto::Field> for arrow::datatypes::Field {
274 fn into_proto(&self) -> proto::Field {
275 proto::Field {
276 name: self.name().clone(),
277 nullable: self.is_nullable(),
278 data_type: Some(Box::new(self.data_type().into_proto())),
279 }
280 }
281
282 fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
283 let proto::Field {
284 name,
285 nullable,
286 data_type,
287 } = proto;
288 let data_type =
289 data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
290 let data_type = (*data_type).into_rust()?;
291
292 Ok(Field::new(name, data_type, nullable))
293 }
294}
295
296impl RustType<proto::Buffer> for arrow::buffer::Buffer {
297 fn into_proto(&self) -> proto::Buffer {
298 #[repr(transparent)]
300 struct BufferWrapper(arrow::buffer::Buffer);
301 impl AsRef<[u8]> for BufferWrapper {
302 fn as_ref(&self) -> &[u8] {
303 &*self.0
304 }
305 }
306 proto::Buffer {
307 data: bytes::Bytes::from_owner(BufferWrapper(self.clone())),
308 }
309 }
310
311 fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
312 Ok(arrow::buffer::Buffer::from(proto.data))
313 }
314}
315
316impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
317 fn into_proto(&self) -> proto::BooleanBuffer {
318 proto::BooleanBuffer {
319 buffer: Some(self.sliced().into_proto()),
320 length: u64::cast_from(self.len()),
321 }
322 }
323
324 fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
325 let proto::BooleanBuffer { buffer, length } = proto;
326 let buffer = buffer.into_rust_if_some("buffer")?;
327 Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
328 }
329}
330
/// A typed view of an arrow array, limited to the array types that
/// [`ArrayOrd::new`] knows how to unpack. Individual entries are addressed,
/// displayed, and compared through [`ArrayIdx`].
#[derive(Clone, Debug)]
pub enum ArrayOrd {
    /// An array of nulls.
    Null(NullArray),
    /// An array of booleans.
    Bool(BooleanArray),
    /// An array of signed 8-bit integers.
    Int8(Int8Array),
    /// An array of signed 16-bit integers.
    Int16(Int16Array),
    /// An array of signed 32-bit integers.
    Int32(Int32Array),
    /// An array of signed 64-bit integers.
    Int64(Int64Array),
    /// An array of unsigned 8-bit integers.
    UInt8(UInt8Array),
    /// An array of unsigned 16-bit integers.
    UInt16(UInt16Array),
    /// An array of unsigned 32-bit integers.
    UInt32(UInt32Array),
    /// An array of unsigned 64-bit integers.
    UInt64(UInt64Array),
    /// An array of 32-bit floats.
    Float32(Float32Array),
    /// An array of 64-bit floats.
    Float64(Float64Array),
    /// An array of strings.
    String(StringArray),
    /// An array of variable-length binary values.
    Binary(BinaryArray),
    /// An array of fixed-length binary values.
    FixedSizeBinary(FixedSizeBinaryArray),
    /// A list array: its optional null buffer, its `i32` offsets, and the
    /// wrapped values array.
    List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
    /// A struct array: its optional null buffer and one wrapped array per column.
    Struct(Option<NullBuffer>, Vec<ArrayOrd>),
}
369
370impl ArrayOrd {
371 pub fn new(array: &dyn Array) -> Self {
373 match array.data_type() {
374 DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
375 DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
376 DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
377 DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
378 DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
379 DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
380 DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
381 DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
382 DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
383 DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
384 DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
385 DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
386 DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
387 DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
388 DataType::FixedSizeBinary(_) => {
389 ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
390 }
391 DataType::List(_) => {
392 let list_array = array.as_list();
393 ArrayOrd::List(
394 list_array.nulls().cloned(),
395 list_array.offsets().clone(),
396 Box::new(ArrayOrd::new(list_array.values())),
397 )
398 }
399 DataType::Struct(_) => {
400 let struct_array = array.as_struct();
401 let nulls = array.nulls().cloned();
402 let columns: Vec<_> = struct_array
403 .columns()
404 .iter()
405 .map(|a| ArrayOrd::new(a))
406 .collect();
407 ArrayOrd::Struct(nulls, columns)
408 }
409 data_type => unimplemented!("array type {data_type:?} not yet supported"),
410 }
411 }
412
413 pub fn goodbytes(&self) -> usize {
416 match self {
417 ArrayOrd::Null(_) => 0,
418 ArrayOrd::Bool(b) => b.len(),
421 ArrayOrd::Int8(a) => a.values().inner().len(),
422 ArrayOrd::Int16(a) => a.values().inner().len(),
423 ArrayOrd::Int32(a) => a.values().inner().len(),
424 ArrayOrd::Int64(a) => a.values().inner().len(),
425 ArrayOrd::UInt8(a) => a.values().inner().len(),
426 ArrayOrd::UInt16(a) => a.values().inner().len(),
427 ArrayOrd::UInt32(a) => a.values().inner().len(),
428 ArrayOrd::UInt64(a) => a.values().inner().len(),
429 ArrayOrd::Float32(a) => a.values().inner().len(),
430 ArrayOrd::Float64(a) => a.values().inner().len(),
431 ArrayOrd::String(a) => a.values().len(),
432 ArrayOrd::Binary(a) => a.values().len(),
433 ArrayOrd::FixedSizeBinary(a) => a.values().len(),
434 ArrayOrd::List(_, _, nested) => nested.goodbytes(),
435 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
436 }
437 }
438
439 pub fn at(&self, idx: usize) -> ArrayIdx<'_> {
441 ArrayIdx { idx, array: self }
442 }
443}
444
/// A reference to a single entry of an [`ArrayOrd`], identified by its index.
///
/// Ordering, equality, and `Display` are all defined on the referenced entry.
/// Comparing entries whose arrays are of different variants panics.
#[derive(Clone, Copy, Debug)]
pub struct ArrayIdx<'a> {
    /// Index of the entry within `array`.
    pub idx: usize,
    /// The array the entry lives in.
    pub array: &'a ArrayOrd,
}
459
460impl Display for ArrayIdx<'_> {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 match self.array {
463 ArrayOrd::Null(_) => write!(f, "null"),
464 ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
465 ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
466 ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
467 ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
468 ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
469 ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
470 ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
471 ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
472 ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
473 ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
474 ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
475 ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
476 ArrayOrd::Binary(a) => {
477 for byte in a.value(self.idx) {
478 write!(f, "{:02x}", byte)?;
479 }
480 Ok(())
481 }
482 ArrayOrd::FixedSizeBinary(a) => {
483 for byte in a.value(self.idx) {
484 write!(f, "{:02x}", byte)?;
485 }
486 Ok(())
487 }
488 ArrayOrd::List(_, offsets, nested) => {
489 write!(
490 f,
491 "[{}]",
492 mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
493 )
494 }
495 ArrayOrd::Struct(_, nested) => write!(
496 f,
497 "{{{}}}",
498 mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
499 ),
500 }
501 }
502}
503
504#[inline]
505fn list_range<'a>(
506 offsets: &OffsetBuffer<i32>,
507 values: &'a ArrayOrd,
508 idx: usize,
509) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
510 let offsets = offsets.inner();
511 let from = offsets[idx].as_usize();
512 let to = offsets[idx + 1].as_usize();
513 (from..to).map(|i| values.at(i))
514}
515
516impl<'a> ArrayIdx<'a> {
517 pub fn goodbytes(&self) -> usize {
520 match self.array {
521 ArrayOrd::Null(_) => 0,
522 ArrayOrd::Bool(_) => size_of::<bool>(),
523 ArrayOrd::Int8(_) => size_of::<i8>(),
524 ArrayOrd::Int16(_) => size_of::<i16>(),
525 ArrayOrd::Int32(_) => size_of::<i32>(),
526 ArrayOrd::Int64(_) => size_of::<i64>(),
527 ArrayOrd::UInt8(_) => size_of::<u8>(),
528 ArrayOrd::UInt16(_) => size_of::<u16>(),
529 ArrayOrd::UInt32(_) => size_of::<u32>(),
530 ArrayOrd::UInt64(_) => size_of::<u64>(),
531 ArrayOrd::Float32(_) => size_of::<f32>(),
532 ArrayOrd::Float64(_) => size_of::<f64>(),
533 ArrayOrd::String(a) => a.value(self.idx).len(),
534 ArrayOrd::Binary(a) => a.value(self.idx).len(),
535 ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
536 ArrayOrd::List(_, offsets, nested) => {
537 list_range(offsets, nested, self.idx)
539 .map(|a| a.goodbytes())
540 .sum()
541 }
542 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
543 }
544 }
545}
546
547impl<'a> Ord for ArrayIdx<'a> {
548 fn cmp(&self, other: &Self) -> Ordering {
549 #[inline]
550 fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
551 buffer.as_ref().map_or(false, |b| b.is_null(idx))
552 }
553 #[inline]
554 fn cmp<A: ArrayAccessor>(
555 left: A,
556 left_idx: usize,
557 right: A,
558 right_idx: usize,
559 cmp: fn(&A::Item, &A::Item) -> Ordering,
560 ) -> Ordering {
561 match (left.is_null(left_idx), right.is_null(right_idx)) {
563 (false, true) => Ordering::Less,
564 (true, true) => Ordering::Equal,
565 (true, false) => Ordering::Greater,
566 (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
567 }
568 }
569 match (&self.array, &other.array) {
570 (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
571 debug_assert!(
572 self.idx < s.len() && other.idx < o.len(),
573 "null array indices in bounds"
574 );
575 Ordering::Equal
576 }
577 (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
579 (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
580 (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
581 (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
582 (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
583 (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
584 (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
585 (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
586 (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
587 (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
588 cmp(s, self.idx, o, other.idx, f32::total_cmp)
589 }
590 (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
591 cmp(s, self.idx, o, other.idx, f64::total_cmp)
592 }
593 (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
594 (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
595 (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
596 cmp(s, self.idx, o, other.idx, Ord::cmp)
597 }
598 (
601 ArrayOrd::List(s_nulls, s_offset, s_values),
602 ArrayOrd::List(o_nulls, o_offset, o_values),
603 ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
604 (false, true) => Ordering::Less,
605 (true, true) => Ordering::Equal,
606 (true, false) => Ordering::Greater,
607 (false, false) => list_range(s_offset, s_values, self.idx)
608 .cmp(list_range(o_offset, o_values, other.idx)),
609 },
610 (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
613 match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
614 (false, true) => Ordering::Less,
615 (true, true) => Ordering::Equal,
616 (true, false) => Ordering::Greater,
617 (false, false) => {
618 let s = s_cols.iter().map(|array| array.at(self.idx));
619 let o = o_cols.iter().map(|array| array.at(other.idx));
620 s.cmp(o)
621 }
622 }
623 }
624 (_, _) => panic!("array types did not match"),
625 }
626 }
627}
628
629impl<'a> PartialOrd for ArrayIdx<'a> {
630 fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
631 Some(self.cmp(other))
632 }
633}
634
635impl<'a> PartialEq for ArrayIdx<'a> {
636 fn eq(&self, other: &ArrayIdx) -> bool {
637 self.cmp(other) == Ordering::Equal
638 }
639}
640
641impl<'a> Eq for ArrayIdx<'a> {}
642
/// A single entry of an array, keeping both the original arrow array and its
/// [`ArrayOrd`] view so the entry can be compared as well as re-encoded.
#[derive(Debug, Clone)]
pub struct ArrayBound {
    /// The original array; `to_proto_lower` reads the entry from it.
    raw: ArrayRef,
    /// Comparable view built from `raw`.
    ord: ArrayOrd,
    /// Index of the entry within the array.
    index: usize,
}
650
651impl PartialEq for ArrayBound {
652 fn eq(&self, other: &Self) -> bool {
653 self.get().eq(&other.get())
654 }
655}
656
// Marker impl: `PartialEq` for `ArrayBound` compares the referenced entries
// via `ArrayIdx`, which itself implements `Eq`.
impl Eq for ArrayBound {}
658
659impl ArrayBound {
660 pub fn new(array: ArrayRef, index: usize) -> Self {
662 Self {
663 ord: ArrayOrd::new(array.as_ref()),
664 raw: array,
665 index,
666 }
667 }
668
669 pub fn get(&self) -> ArrayIdx<'_> {
671 self.ord.at(self.index)
672 }
673
674 pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
678 let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
681 let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
682 let array_data = taken.into_data();
683
684 let mut proto = array_data.into_proto();
685 let original_len = proto.encoded_len();
686 if original_len <= max_len {
687 return Some(proto);
688 }
689
690 let mut data_type = proto.data_type.take()?;
691 maybe_trim_proto(&mut data_type, &mut proto, max_len);
692 proto.data_type = Some(data_type);
693
694 if cfg!(debug_assertions) {
695 let array: ArrayData = proto
696 .clone()
697 .into_rust()
698 .expect("trimmed array data can still be decoded");
699 assert_eq!(array.len(), 1);
700 let new_bound = Self::new(make_array(array), 0);
701 assert!(
702 new_bound.get() <= self.get(),
703 "trimmed bound should be comparable to and no larger than the original data"
704 )
705 }
706
707 if proto.encoded_len() <= max_len {
708 Some(proto)
709 } else {
710 None
711 }
712 }
713}
714
715fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
723 assert!(body.data_type.is_none(), "expected separate data type");
724 let encoded_len = data_type.encoded_len() + body.encoded_len();
726 match &mut data_type.kind {
727 Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
728 let mut struct_len = encoded_len;
730 while struct_len > max_len {
731 let Some(mut child) = body.children.pop() else {
732 break;
733 };
734 let Some(mut field) = fields.pop() else { break };
735
736 struct_len -= field.encoded_len() + child.encoded_len();
737 if let Some(remaining_len) = max_len.checked_sub(struct_len) {
738 let Some(field_type) = field.data_type.as_mut() else {
741 break;
742 };
743 maybe_trim_proto(field_type, &mut child, remaining_len);
744 if field.encoded_len() + child.encoded_len() <= remaining_len {
745 fields.push(field);
746 body.children.push(child);
747 }
748 break;
749 }
750 }
751 }
752 _ => {}
753 };
754}
755
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// `to_proto_lower` should drop trailing struct columns until the proto
    /// fits, and give up when nothing fits.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets too small to hold anything at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The 250-byte string in column "b" does not fit in 100 bytes, so
        // "b" and everything after it is dropped.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        )
    }

    /// A struct with fewer columns is comparable to one with more, as long
    /// as the shared columns have matching types.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        // With an equal shared column, the shorter struct sorts first.
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Comparing structs whose columns have different types must panic.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        // Column "a" is equal on both sides, so the comparison reaches
        // column "b", where a string array meets a boolean array.
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // The comparison itself panics; the assertion result is never used.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}