1use std::cmp::Ordering;
26use std::fmt::{Debug, Display};
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
// Pulls in the protobuf definitions from the `mz_persist_types.arrow.rs` file that the build
// places in `OUT_DIR`. The included code is exempted from the `missing_docs` lint.
#[allow(missing_docs)]
mod proto {
    include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47 match data_type {
48 DataType::Struct(fields) => fields,
49 DataType::List(field) => std::slice::from_ref(field),
50 DataType::Map(field, _) => std::slice::from_ref(field),
51 DataType::Null
52 | DataType::Boolean
53 | DataType::Int8
54 | DataType::Int16
55 | DataType::Int32
56 | DataType::Int64
57 | DataType::UInt8
58 | DataType::UInt16
59 | DataType::UInt32
60 | DataType::UInt64
61 | DataType::Float16
62 | DataType::Float32
63 | DataType::Float64
64 | DataType::Timestamp(_, _)
65 | DataType::Date32
66 | DataType::Date64
67 | DataType::Time32(_)
68 | DataType::Time64(_)
69 | DataType::Duration(_)
70 | DataType::Interval(_)
71 | DataType::Binary
72 | DataType::FixedSizeBinary(_)
73 | DataType::LargeBinary
74 | DataType::BinaryView
75 | DataType::Utf8
76 | DataType::LargeUtf8
77 | DataType::Utf8View
78 | DataType::Decimal32(_, _)
79 | DataType::Decimal64(_, _)
80 | DataType::Decimal128(_, _)
81 | DataType::Decimal256(_, _) => &[],
82 DataType::ListView(_)
83 | DataType::FixedSizeList(_, _)
84 | DataType::LargeList(_)
85 | DataType::LargeListView(_)
86 | DataType::Union(_, _)
87 | DataType::Dictionary(_, _)
88 | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
89 }
90}
91
92fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
95 let data_type = match expected_type {
96 Some(expected) => {
97 soft_assert_eq_no_log!(
100 expected,
101 data.data_type(),
102 "actual type should match expected type"
103 );
104 None
105 }
106 None => Some(data.data_type().into_proto()),
107 };
108
109 ProtoArrayData {
110 data_type,
111 length: u64::cast_from(data.len()),
112 offset: u64::cast_from(data.offset()),
113 buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
114 children: data
115 .child_data()
116 .iter()
117 .zip_eq(fields_for_type(
118 expected_type.unwrap_or_else(|| data.data_type()),
119 ))
120 .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
121 .collect(),
122 nulls: data.nulls().map(|n| n.inner().into_proto()),
123 }
124}
125
126fn from_proto_with_type(
129 proto: ProtoArrayData,
130 expected_type: Option<&DataType>,
131) -> Result<ArrayData, TryFromProtoError> {
132 let ProtoArrayData {
133 data_type,
134 length,
135 offset,
136 buffers,
137 children,
138 nulls,
139 } = proto;
140 let data_type: Option<DataType> = data_type.into_rust()?;
141 let data_type = match (data_type, expected_type) {
142 (Some(data_type), None) => data_type,
143 (Some(data_type), Some(expected_type)) => {
144 soft_assert_eq_no_log!(
147 data_type,
148 *expected_type,
149 "expected type should match actual type"
150 );
151 data_type
152 }
153 (None, Some(expected_type)) => expected_type.clone(),
154 (None, None) => {
155 return Err(TryFromProtoError::MissingField(
156 "ProtoArrayData::data_type".to_string(),
157 ));
158 }
159 };
160 let nulls = nulls
161 .map(|n| n.into_rust())
162 .transpose()?
163 .map(NullBuffer::new);
164
165 let mut builder = ArrayDataBuilder::new(data_type.clone())
166 .len(usize::cast_from(length))
167 .offset(usize::cast_from(offset))
168 .nulls(nulls);
169
170 for b in buffers.into_iter().map(|b| b.into_rust()) {
171 builder = builder.add_buffer(b?);
172 }
173 for c in children
174 .into_iter()
175 .zip_eq(fields_for_type(&data_type))
176 .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
177 {
178 builder = builder.add_child_data(c?);
179 }
180
181 builder
183 .align_buffers(true)
184 .build()
185 .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
186}
187
188impl RustType<ProtoArrayData> for arrow::array::ArrayData {
189 fn into_proto(&self) -> ProtoArrayData {
190 into_proto_with_type(self, None)
191 }
192
193 fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
194 from_proto_with_type(proto, None)
195 }
196}
197
198impl RustType<proto::DataType> for arrow::datatypes::DataType {
199 fn into_proto(&self) -> proto::DataType {
200 let kind = match self {
201 DataType::Null => proto::data_type::Kind::Null(()),
202 DataType::Boolean => proto::data_type::Kind::Boolean(()),
203 DataType::UInt8 => proto::data_type::Kind::Uint8(()),
204 DataType::UInt16 => proto::data_type::Kind::Uint16(()),
205 DataType::UInt32 => proto::data_type::Kind::Uint32(()),
206 DataType::UInt64 => proto::data_type::Kind::Uint64(()),
207 DataType::Int8 => proto::data_type::Kind::Int8(()),
208 DataType::Int16 => proto::data_type::Kind::Int16(()),
209 DataType::Int32 => proto::data_type::Kind::Int32(()),
210 DataType::Int64 => proto::data_type::Kind::Int64(()),
211 DataType::Float32 => proto::data_type::Kind::Float32(()),
212 DataType::Float64 => proto::data_type::Kind::Float64(()),
213 DataType::Utf8 => proto::data_type::Kind::String(()),
214 DataType::Binary => proto::data_type::Kind::Binary(()),
215 DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
216 DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
217 DataType::Map(inner, sorted) => {
218 let map = proto::data_type::Map {
219 value: Some(Box::new(inner.into_proto())),
220 sorted: *sorted,
221 };
222 proto::data_type::Kind::Map(Box::new(map))
223 }
224 DataType::Struct(children) => {
225 let children = children.into_iter().map(|f| f.into_proto()).collect();
226 proto::data_type::Kind::Struct(proto::data_type::Struct { children })
227 }
228 other => unimplemented!("unsupported data type {other:?}"),
229 };
230
231 proto::DataType { kind: Some(kind) }
232 }
233
234 fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
235 let data_type = proto
236 .kind
237 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
238 let data_type = match data_type {
239 proto::data_type::Kind::Null(()) => DataType::Null,
240 proto::data_type::Kind::Boolean(()) => DataType::Boolean,
241 proto::data_type::Kind::Uint8(()) => DataType::UInt8,
242 proto::data_type::Kind::Uint16(()) => DataType::UInt16,
243 proto::data_type::Kind::Uint32(()) => DataType::UInt32,
244 proto::data_type::Kind::Uint64(()) => DataType::UInt64,
245 proto::data_type::Kind::Int8(()) => DataType::Int8,
246 proto::data_type::Kind::Int16(()) => DataType::Int16,
247 proto::data_type::Kind::Int32(()) => DataType::Int32,
248 proto::data_type::Kind::Int64(()) => DataType::Int64,
249 proto::data_type::Kind::Float32(()) => DataType::Float32,
250 proto::data_type::Kind::Float64(()) => DataType::Float64,
251 proto::data_type::Kind::String(()) => DataType::Utf8,
252 proto::data_type::Kind::Binary(()) => DataType::Binary,
253 proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
254 proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
255 proto::data_type::Kind::Map(inner) => {
256 let value = inner
257 .value
258 .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
259 DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
260 }
261 proto::data_type::Kind::Struct(inner) => {
262 let children: Vec<Field> = inner
263 .children
264 .into_iter()
265 .map(|c| c.into_rust())
266 .collect::<Result<_, _>>()?;
267 DataType::Struct(Fields::from(children))
268 }
269 };
270
271 Ok(data_type)
272 }
273}
274
275impl RustType<proto::Field> for arrow::datatypes::Field {
276 fn into_proto(&self) -> proto::Field {
277 proto::Field {
278 name: self.name().clone(),
279 nullable: self.is_nullable(),
280 data_type: Some(Box::new(self.data_type().into_proto())),
281 }
282 }
283
284 fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
285 let proto::Field {
286 name,
287 nullable,
288 data_type,
289 } = proto;
290 let data_type =
291 data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
292 let data_type = (*data_type).into_rust()?;
293
294 Ok(Field::new(name, data_type, nullable))
295 }
296}
297
298impl RustType<proto::Buffer> for arrow::buffer::Buffer {
299 fn into_proto(&self) -> proto::Buffer {
300 #[repr(transparent)]
302 struct BufferWrapper(arrow::buffer::Buffer);
303 impl AsRef<[u8]> for BufferWrapper {
304 fn as_ref(&self) -> &[u8] {
305 &*self.0
306 }
307 }
308 proto::Buffer {
309 data: bytes::Bytes::from_owner(BufferWrapper(self.clone())),
310 }
311 }
312
313 fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
314 Ok(arrow::buffer::Buffer::from(proto.data))
315 }
316}
317
318impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
319 fn into_proto(&self) -> proto::BooleanBuffer {
320 proto::BooleanBuffer {
321 buffer: Some(self.sliced().into_proto()),
322 length: u64::cast_from(self.len()),
323 }
324 }
325
326 fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
327 let proto::BooleanBuffer { buffer, length } = proto;
328 let buffer = buffer.into_rust_if_some("buffer")?;
329 Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
330 }
331}
332
/// A concretely-typed view of an arrow array, restricted to the array types this module
/// supports. Built with [`ArrayOrd::new`]; individual elements are addressed (and compared)
/// through [`ArrayOrd::at`] / [`ArrayIdx`].
#[derive(Clone)]
pub enum ArrayOrd {
    /// An array of nulls.
    Null(NullArray),
    /// An array of booleans.
    Bool(BooleanArray),
    /// An array of `i8`s.
    Int8(Int8Array),
    /// An array of `i16`s.
    Int16(Int16Array),
    /// An array of `i32`s.
    Int32(Int32Array),
    /// An array of `i64`s.
    Int64(Int64Array),
    /// An array of `u8`s.
    UInt8(UInt8Array),
    /// An array of `u16`s.
    UInt16(UInt16Array),
    /// An array of `u32`s.
    UInt32(UInt32Array),
    /// An array of `u64`s.
    UInt64(UInt64Array),
    /// An array of `f32`s.
    Float32(Float32Array),
    /// An array of `f64`s.
    Float64(Float64Array),
    /// An array of strings.
    String(StringArray),
    /// An array of variable-length binary values.
    Binary(BinaryArray),
    /// An array of fixed-length binary values.
    FixedSizeBinary(FixedSizeBinaryArray),
    /// A list array: its optional null buffer, its offsets into the values, and the values.
    List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
    /// A struct array: its optional null buffer and one entry per column.
    Struct(Option<NullBuffer>, Vec<ArrayOrd>),
}
371
372impl ArrayOrd {
373 pub fn new(array: &dyn Array) -> Self {
375 match array.data_type() {
376 DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
377 DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
378 DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
379 DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
380 DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
381 DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
382 DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
383 DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
384 DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
385 DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
386 DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
387 DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
388 DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
389 DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
390 DataType::FixedSizeBinary(_) => {
391 ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
392 }
393 DataType::List(_) => {
394 let list_array = array.as_list();
395 ArrayOrd::List(
396 list_array.nulls().cloned(),
397 list_array.offsets().clone(),
398 Box::new(ArrayOrd::new(list_array.values())),
399 )
400 }
401 DataType::Struct(_) => {
402 let struct_array = array.as_struct();
403 let nulls = array.nulls().cloned();
404 let columns: Vec<_> = struct_array
405 .columns()
406 .iter()
407 .map(|a| ArrayOrd::new(a))
408 .collect();
409 ArrayOrd::Struct(nulls, columns)
410 }
411 data_type => unimplemented!("array type {data_type:?} not yet supported"),
412 }
413 }
414
415 pub fn goodbytes(&self) -> usize {
418 match self {
419 ArrayOrd::Null(_) => 0,
420 ArrayOrd::Bool(b) => b.len(),
423 ArrayOrd::Int8(a) => a.values().inner().len(),
424 ArrayOrd::Int16(a) => a.values().inner().len(),
425 ArrayOrd::Int32(a) => a.values().inner().len(),
426 ArrayOrd::Int64(a) => a.values().inner().len(),
427 ArrayOrd::UInt8(a) => a.values().inner().len(),
428 ArrayOrd::UInt16(a) => a.values().inner().len(),
429 ArrayOrd::UInt32(a) => a.values().inner().len(),
430 ArrayOrd::UInt64(a) => a.values().inner().len(),
431 ArrayOrd::Float32(a) => a.values().inner().len(),
432 ArrayOrd::Float64(a) => a.values().inner().len(),
433 ArrayOrd::String(a) => a.values().len(),
434 ArrayOrd::Binary(a) => a.values().len(),
435 ArrayOrd::FixedSizeBinary(a) => a.values().len(),
436 ArrayOrd::List(_, _, nested) => nested.goodbytes(),
437 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
438 }
439 }
440
441 pub fn at(&self, idx: usize) -> ArrayIdx<'_> {
443 ArrayIdx { idx, array: self }
444 }
445}
446
447impl Debug for ArrayOrd {
448 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449 struct DebugType<'a>(&'a ArrayOrd);
450
451 impl Debug for DebugType<'_> {
452 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453 match self.0 {
454 ArrayOrd::Null(_) => write!(f, "Null"),
455 ArrayOrd::Bool(_) => write!(f, "Bool"),
456 ArrayOrd::Int8(_) => write!(f, "Int8"),
457 ArrayOrd::Int16(_) => write!(f, "Int16"),
458 ArrayOrd::Int32(_) => write!(f, "Int32"),
459 ArrayOrd::Int64(_) => write!(f, "Int64"),
460 ArrayOrd::UInt8(_) => write!(f, "UInt8"),
461 ArrayOrd::UInt16(_) => write!(f, "UInt16"),
462 ArrayOrd::UInt32(_) => write!(f, "UInt32"),
463 ArrayOrd::UInt64(_) => write!(f, "UInt64"),
464 ArrayOrd::Float32(_) => write!(f, "Float32"),
465 ArrayOrd::Float64(_) => write!(f, "Float64"),
466 ArrayOrd::String(_) => write!(f, "String"),
467 ArrayOrd::Binary(_) => write!(f, "Binary"),
468 ArrayOrd::FixedSizeBinary(a) => f
469 .debug_tuple("FixedSizeBinary")
470 .field(&a.value_length())
471 .finish(),
472 ArrayOrd::List(_, _, nested) => f.debug_tuple("List").field(&*nested).finish(),
473 ArrayOrd::Struct(_, fields) => {
474 let mut tuple = f.debug_tuple("Struct");
475 for field in fields {
476 tuple.field(field);
477 }
478 tuple.finish()
479 }
480 }
481 }
482 }
483
484 f.debug_struct("ArrayOrd")
485 .field("type", &DebugType(self))
486 .field("goodbytes", &self.goodbytes())
487 .finish()
488 }
489}
490
/// A reference to a single element of an [`ArrayOrd`], identified by its index.
///
/// This is the type that implements `Ord`, `Display`, etc. for individual array elements.
/// Comparing two values whose arrays have different types panics.
#[derive(Clone, Copy, Debug)]
pub struct ArrayIdx<'a> {
    /// The position of the element within `array`.
    pub idx: usize,
    /// The array the element lives in.
    pub array: &'a ArrayOrd,
}
505
506impl Display for ArrayIdx<'_> {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 match self.array {
509 ArrayOrd::Null(_) => write!(f, "null"),
510 ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
511 ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
512 ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
513 ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
514 ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
515 ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
516 ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
517 ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
518 ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
519 ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
520 ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
521 ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
522 ArrayOrd::Binary(a) => {
523 for byte in a.value(self.idx) {
524 write!(f, "{:02x}", byte)?;
525 }
526 Ok(())
527 }
528 ArrayOrd::FixedSizeBinary(a) => {
529 for byte in a.value(self.idx) {
530 write!(f, "{:02x}", byte)?;
531 }
532 Ok(())
533 }
534 ArrayOrd::List(_, offsets, nested) => {
535 write!(
536 f,
537 "[{}]",
538 mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
539 )
540 }
541 ArrayOrd::Struct(_, nested) => write!(
542 f,
543 "{{{}}}",
544 mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
545 ),
546 }
547 }
548}
549
550#[inline]
551fn list_range<'a>(
552 offsets: &OffsetBuffer<i32>,
553 values: &'a ArrayOrd,
554 idx: usize,
555) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
556 let offsets = offsets.inner();
557 let from = offsets[idx].as_usize();
558 let to = offsets[idx + 1].as_usize();
559 (from..to).map(|i| values.at(i))
560}
561
562impl<'a> ArrayIdx<'a> {
563 pub fn goodbytes(&self) -> usize {
566 match self.array {
567 ArrayOrd::Null(_) => 0,
568 ArrayOrd::Bool(_) => size_of::<bool>(),
569 ArrayOrd::Int8(_) => size_of::<i8>(),
570 ArrayOrd::Int16(_) => size_of::<i16>(),
571 ArrayOrd::Int32(_) => size_of::<i32>(),
572 ArrayOrd::Int64(_) => size_of::<i64>(),
573 ArrayOrd::UInt8(_) => size_of::<u8>(),
574 ArrayOrd::UInt16(_) => size_of::<u16>(),
575 ArrayOrd::UInt32(_) => size_of::<u32>(),
576 ArrayOrd::UInt64(_) => size_of::<u64>(),
577 ArrayOrd::Float32(_) => size_of::<f32>(),
578 ArrayOrd::Float64(_) => size_of::<f64>(),
579 ArrayOrd::String(a) => a.value(self.idx).len(),
580 ArrayOrd::Binary(a) => a.value(self.idx).len(),
581 ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
582 ArrayOrd::List(_, offsets, nested) => {
583 list_range(offsets, nested, self.idx)
585 .map(|a| a.goodbytes())
586 .sum()
587 }
588 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
589 }
590 }
591}
592
593impl<'a> Ord for ArrayIdx<'a> {
594 fn cmp(&self, other: &Self) -> Ordering {
595 #[inline]
596 fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
597 buffer.as_ref().map_or(false, |b| b.is_null(idx))
598 }
599 #[inline]
600 fn cmp<A: ArrayAccessor>(
601 left: A,
602 left_idx: usize,
603 right: A,
604 right_idx: usize,
605 cmp: fn(&A::Item, &A::Item) -> Ordering,
606 ) -> Ordering {
607 match (left.is_null(left_idx), right.is_null(right_idx)) {
609 (false, true) => Ordering::Less,
610 (true, true) => Ordering::Equal,
611 (true, false) => Ordering::Greater,
612 (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
613 }
614 }
615 match (&self.array, &other.array) {
616 (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
617 debug_assert!(
618 self.idx < s.len() && other.idx < o.len(),
619 "null array indices in bounds"
620 );
621 Ordering::Equal
622 }
623 (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
625 (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
626 (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
627 (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
628 (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
629 (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
630 (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
631 (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
632 (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
633 (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
634 cmp(s, self.idx, o, other.idx, f32::total_cmp)
635 }
636 (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
637 cmp(s, self.idx, o, other.idx, f64::total_cmp)
638 }
639 (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
640 (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
641 (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
642 cmp(s, self.idx, o, other.idx, Ord::cmp)
643 }
644 (
647 ArrayOrd::List(s_nulls, s_offset, s_values),
648 ArrayOrd::List(o_nulls, o_offset, o_values),
649 ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
650 (false, true) => Ordering::Less,
651 (true, true) => Ordering::Equal,
652 (true, false) => Ordering::Greater,
653 (false, false) => list_range(s_offset, s_values, self.idx)
654 .cmp(list_range(o_offset, o_values, other.idx)),
655 },
656 (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
659 match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
660 (false, true) => Ordering::Less,
661 (true, true) => Ordering::Equal,
662 (true, false) => Ordering::Greater,
663 (false, false) => {
664 let s = s_cols.iter().map(|array| array.at(self.idx));
665 let o = o_cols.iter().map(|array| array.at(other.idx));
666 s.cmp(o)
667 }
668 }
669 }
670 (a, b) => panic!("array types did not match! {a:?} vs. {b:?}",),
671 }
672 }
673}
674
675impl<'a> PartialOrd for ArrayIdx<'a> {
676 fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
677 Some(self.cmp(other))
678 }
679}
680
681impl<'a> PartialEq for ArrayIdx<'a> {
682 fn eq(&self, other: &ArrayIdx) -> bool {
683 self.cmp(other) == Ordering::Equal
684 }
685}
686
687impl<'a> Eq for ArrayIdx<'a> {}
688
/// A single element of an arrow array, kept together with the array that contains it.
/// Intended for use as a bound: see [`ArrayBound::to_proto_lower`].
#[derive(Debug, Clone)]
pub struct ArrayBound {
    /// The array as originally provided.
    raw: ArrayRef,
    /// The same array wrapped as an [`ArrayOrd`], used for comparisons.
    ord: ArrayOrd,
    /// The position of the element within the array.
    index: usize,
}
696
697impl PartialEq for ArrayBound {
698 fn eq(&self, other: &Self) -> bool {
699 self.get().eq(&other.get())
700 }
701}
702
703impl Eq for ArrayBound {}
704
705impl ArrayBound {
706 pub fn new(array: ArrayRef, index: usize) -> Self {
708 Self {
709 ord: ArrayOrd::new(array.as_ref()),
710 raw: array,
711 index,
712 }
713 }
714
715 pub fn get(&self) -> ArrayIdx<'_> {
717 self.ord.at(self.index)
718 }
719
720 pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
724 let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
727 let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
728 let array_data = taken.into_data();
729
730 let mut proto = array_data.into_proto();
731 let original_len = proto.encoded_len();
732 if original_len <= max_len {
733 return Some(proto);
734 }
735
736 let mut data_type = proto.data_type.take()?;
737 maybe_trim_proto(&mut data_type, &mut proto, max_len);
738 proto.data_type = Some(data_type);
739
740 if cfg!(debug_assertions) {
741 let array: ArrayData = proto
742 .clone()
743 .into_rust()
744 .expect("trimmed array data can still be decoded");
745 assert_eq!(array.len(), 1);
746 let new_bound = Self::new(make_array(array), 0);
747 assert!(
748 new_bound.get() <= self.get(),
749 "trimmed bound should be comparable to and no larger than the original data"
750 )
751 }
752
753 if proto.encoded_len() <= max_len {
754 Some(proto)
755 } else {
756 None
757 }
758 }
759}
760
761fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
769 assert!(body.data_type.is_none(), "expected separate data type");
770 let encoded_len = data_type.encoded_len() + body.encoded_len();
772 match &mut data_type.kind {
773 Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
774 let mut struct_len = encoded_len;
776 while struct_len > max_len {
777 let Some(mut child) = body.children.pop() else {
778 break;
779 };
780 let Some(mut field) = fields.pop() else { break };
781
782 struct_len -= field.encoded_len() + child.encoded_len();
783 if let Some(remaining_len) = max_len.checked_sub(struct_len) {
784 let Some(field_type) = field.data_type.as_mut() else {
787 break;
788 };
789 maybe_trim_proto(field_type, &mut child, remaining_len);
790 if field.encoded_len() + child.encoded_len() <= remaining_len {
791 fields.push(field);
792 body.children.push(child);
793 }
794 break;
795 }
796 }
797 }
798 _ => {}
799 };
800}
801
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// `to_proto_lower` should drop trailing struct columns until the encoding fits the
    /// budget, and give up when nothing fits.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets too small to hold anything at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The large string column does not fit, so it and everything after it is dropped.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        // A generous budget leaves the data untouched.
        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        )
    }

    /// Structs compare column by column, and a struct with fewer columns sorts before a
    /// longer struct that it is a prefix of.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Comparing structs whose columns have different types panics once the comparison
    /// reaches the mismatched column.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // The first columns are equal, so the comparison moves on to the second column,
        // where the types differ.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}