1use std::cmp::Ordering;
26use std::fmt::{Debug, Display};
27use std::sync::Arc;
28
29use arrow::array::*;
30use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
31use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Fields};
32use itertools::Itertools;
33use mz_ore::cast::CastFrom;
34use mz_ore::soft_assert_eq_no_log;
35use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
36use prost::Message;
37
38#[allow(missing_docs)]
39mod proto {
40 include!(concat!(env!("OUT_DIR"), "/mz_persist_types.arrow.rs"));
41}
42use crate::arrow::proto::data_type;
43pub use proto::{DataType as ProtoDataType, ProtoArrayData};
44
45pub fn fields_for_type(data_type: &DataType) -> &[FieldRef] {
47 match data_type {
48 DataType::Struct(fields) => fields,
49 DataType::List(field) => std::slice::from_ref(field),
50 DataType::Map(field, _) => std::slice::from_ref(field),
51 DataType::Null
52 | DataType::Boolean
53 | DataType::Int8
54 | DataType::Int16
55 | DataType::Int32
56 | DataType::Int64
57 | DataType::UInt8
58 | DataType::UInt16
59 | DataType::UInt32
60 | DataType::UInt64
61 | DataType::Float16
62 | DataType::Float32
63 | DataType::Float64
64 | DataType::Timestamp(_, _)
65 | DataType::Date32
66 | DataType::Date64
67 | DataType::Time32(_)
68 | DataType::Time64(_)
69 | DataType::Duration(_)
70 | DataType::Interval(_)
71 | DataType::Binary
72 | DataType::FixedSizeBinary(_)
73 | DataType::LargeBinary
74 | DataType::BinaryView
75 | DataType::Utf8
76 | DataType::LargeUtf8
77 | DataType::Utf8View
78 | DataType::Decimal128(_, _)
79 | DataType::Decimal256(_, _) => &[],
80 DataType::ListView(_)
81 | DataType::FixedSizeList(_, _)
82 | DataType::LargeList(_)
83 | DataType::LargeListView(_)
84 | DataType::Union(_, _)
85 | DataType::Dictionary(_, _)
86 | DataType::RunEndEncoded(_, _) => unimplemented!("not supported"),
87 }
88}
89
90fn into_proto_with_type(data: &ArrayData, expected_type: Option<&DataType>) -> ProtoArrayData {
93 let data_type = match expected_type {
94 Some(expected) => {
95 soft_assert_eq_no_log!(
98 expected,
99 data.data_type(),
100 "actual type should match expected type"
101 );
102 None
103 }
104 None => Some(data.data_type().into_proto()),
105 };
106
107 ProtoArrayData {
108 data_type,
109 length: u64::cast_from(data.len()),
110 offset: u64::cast_from(data.offset()),
111 buffers: data.buffers().iter().map(|b| b.into_proto()).collect(),
112 children: data
113 .child_data()
114 .iter()
115 .zip_eq(fields_for_type(
116 expected_type.unwrap_or_else(|| data.data_type()),
117 ))
118 .map(|(child, expect)| into_proto_with_type(child, Some(expect.data_type())))
119 .collect(),
120 nulls: data.nulls().map(|n| n.inner().into_proto()),
121 }
122}
123
124fn from_proto_with_type(
127 proto: ProtoArrayData,
128 expected_type: Option<&DataType>,
129) -> Result<ArrayData, TryFromProtoError> {
130 let ProtoArrayData {
131 data_type,
132 length,
133 offset,
134 buffers,
135 children,
136 nulls,
137 } = proto;
138 let data_type: Option<DataType> = data_type.into_rust()?;
139 let data_type = match (data_type, expected_type) {
140 (Some(data_type), None) => data_type,
141 (Some(data_type), Some(expected_type)) => {
142 soft_assert_eq_no_log!(
145 data_type,
146 *expected_type,
147 "expected type should match actual type"
148 );
149 data_type
150 }
151 (None, Some(expected_type)) => expected_type.clone(),
152 (None, None) => {
153 return Err(TryFromProtoError::MissingField(
154 "ProtoArrayData::data_type".to_string(),
155 ));
156 }
157 };
158 let nulls = nulls
159 .map(|n| n.into_rust())
160 .transpose()?
161 .map(NullBuffer::new);
162
163 let mut builder = ArrayDataBuilder::new(data_type.clone())
164 .len(usize::cast_from(length))
165 .offset(usize::cast_from(offset))
166 .nulls(nulls);
167
168 for b in buffers.into_iter().map(|b| b.into_rust()) {
169 builder = builder.add_buffer(b?);
170 }
171 for c in children
172 .into_iter()
173 .zip_eq(fields_for_type(&data_type))
174 .map(|(c, field)| from_proto_with_type(c, Some(field.data_type())))
175 {
176 builder = builder.add_child_data(c?);
177 }
178
179 builder
181 .align_buffers(true)
182 .build()
183 .map_err(|e| TryFromProtoError::RowConversionError(e.to_string()))
184}
185
186impl RustType<ProtoArrayData> for arrow::array::ArrayData {
187 fn into_proto(&self) -> ProtoArrayData {
188 into_proto_with_type(self, None)
189 }
190
191 fn from_proto(proto: ProtoArrayData) -> Result<Self, TryFromProtoError> {
192 from_proto_with_type(proto, None)
193 }
194}
195
196impl RustType<proto::DataType> for arrow::datatypes::DataType {
197 fn into_proto(&self) -> proto::DataType {
198 let kind = match self {
199 DataType::Null => proto::data_type::Kind::Null(()),
200 DataType::Boolean => proto::data_type::Kind::Boolean(()),
201 DataType::UInt8 => proto::data_type::Kind::Uint8(()),
202 DataType::UInt16 => proto::data_type::Kind::Uint16(()),
203 DataType::UInt32 => proto::data_type::Kind::Uint32(()),
204 DataType::UInt64 => proto::data_type::Kind::Uint64(()),
205 DataType::Int8 => proto::data_type::Kind::Int8(()),
206 DataType::Int16 => proto::data_type::Kind::Int16(()),
207 DataType::Int32 => proto::data_type::Kind::Int32(()),
208 DataType::Int64 => proto::data_type::Kind::Int64(()),
209 DataType::Float32 => proto::data_type::Kind::Float32(()),
210 DataType::Float64 => proto::data_type::Kind::Float64(()),
211 DataType::Utf8 => proto::data_type::Kind::String(()),
212 DataType::Binary => proto::data_type::Kind::Binary(()),
213 DataType::FixedSizeBinary(size) => proto::data_type::Kind::FixedBinary(*size),
214 DataType::List(inner) => proto::data_type::Kind::List(Box::new(inner.into_proto())),
215 DataType::Map(inner, sorted) => {
216 let map = proto::data_type::Map {
217 value: Some(Box::new(inner.into_proto())),
218 sorted: *sorted,
219 };
220 proto::data_type::Kind::Map(Box::new(map))
221 }
222 DataType::Struct(children) => {
223 let children = children.into_iter().map(|f| f.into_proto()).collect();
224 proto::data_type::Kind::Struct(proto::data_type::Struct { children })
225 }
226 other => unimplemented!("unsupported data type {other:?}"),
227 };
228
229 proto::DataType { kind: Some(kind) }
230 }
231
232 fn from_proto(proto: proto::DataType) -> Result<Self, TryFromProtoError> {
233 let data_type = proto
234 .kind
235 .ok_or_else(|| TryFromProtoError::missing_field("kind"))?;
236 let data_type = match data_type {
237 proto::data_type::Kind::Null(()) => DataType::Null,
238 proto::data_type::Kind::Boolean(()) => DataType::Boolean,
239 proto::data_type::Kind::Uint8(()) => DataType::UInt8,
240 proto::data_type::Kind::Uint16(()) => DataType::UInt16,
241 proto::data_type::Kind::Uint32(()) => DataType::UInt32,
242 proto::data_type::Kind::Uint64(()) => DataType::UInt64,
243 proto::data_type::Kind::Int8(()) => DataType::Int8,
244 proto::data_type::Kind::Int16(()) => DataType::Int16,
245 proto::data_type::Kind::Int32(()) => DataType::Int32,
246 proto::data_type::Kind::Int64(()) => DataType::Int64,
247 proto::data_type::Kind::Float32(()) => DataType::Float32,
248 proto::data_type::Kind::Float64(()) => DataType::Float64,
249 proto::data_type::Kind::String(()) => DataType::Utf8,
250 proto::data_type::Kind::Binary(()) => DataType::Binary,
251 proto::data_type::Kind::FixedBinary(size) => DataType::FixedSizeBinary(size),
252 proto::data_type::Kind::List(inner) => DataType::List(Arc::new((*inner).into_rust()?)),
253 proto::data_type::Kind::Map(inner) => {
254 let value = inner
255 .value
256 .ok_or_else(|| TryFromProtoError::missing_field("map.value"))?;
257 DataType::Map(Arc::new((*value).into_rust()?), inner.sorted)
258 }
259 proto::data_type::Kind::Struct(inner) => {
260 let children: Vec<Field> = inner
261 .children
262 .into_iter()
263 .map(|c| c.into_rust())
264 .collect::<Result<_, _>>()?;
265 DataType::Struct(Fields::from(children))
266 }
267 };
268
269 Ok(data_type)
270 }
271}
272
273impl RustType<proto::Field> for arrow::datatypes::Field {
274 fn into_proto(&self) -> proto::Field {
275 proto::Field {
276 name: self.name().clone(),
277 nullable: self.is_nullable(),
278 data_type: Some(Box::new(self.data_type().into_proto())),
279 }
280 }
281
282 fn from_proto(proto: proto::Field) -> Result<Self, TryFromProtoError> {
283 let proto::Field {
284 name,
285 nullable,
286 data_type,
287 } = proto;
288 let data_type =
289 data_type.ok_or_else(|| TryFromProtoError::missing_field("field.data_type"))?;
290 let data_type = (*data_type).into_rust()?;
291
292 Ok(Field::new(name, data_type, nullable))
293 }
294}
295
296impl RustType<proto::Buffer> for arrow::buffer::Buffer {
297 fn into_proto(&self) -> proto::Buffer {
298 #[repr(transparent)]
300 struct BufferWrapper(arrow::buffer::Buffer);
301 impl AsRef<[u8]> for BufferWrapper {
302 fn as_ref(&self) -> &[u8] {
303 &*self.0
304 }
305 }
306 proto::Buffer {
307 data: bytes::Bytes::from_owner(BufferWrapper(self.clone())),
308 }
309 }
310
311 fn from_proto(proto: proto::Buffer) -> Result<Self, TryFromProtoError> {
312 Ok(arrow::buffer::Buffer::from(proto.data))
313 }
314}
315
316impl RustType<proto::BooleanBuffer> for arrow::buffer::BooleanBuffer {
317 fn into_proto(&self) -> proto::BooleanBuffer {
318 proto::BooleanBuffer {
319 buffer: Some(self.sliced().into_proto()),
320 length: u64::cast_from(self.len()),
321 }
322 }
323
324 fn from_proto(proto: proto::BooleanBuffer) -> Result<Self, TryFromProtoError> {
325 let proto::BooleanBuffer { buffer, length } = proto;
326 let buffer = buffer.into_rust_if_some("buffer")?;
327 Ok(BooleanBuffer::new(buffer, 0, usize::cast_from(length)))
328 }
329}
330
331#[derive(Clone)]
333pub enum ArrayOrd {
334 Null(NullArray),
336 Bool(BooleanArray),
338 Int8(Int8Array),
340 Int16(Int16Array),
342 Int32(Int32Array),
344 Int64(Int64Array),
346 UInt8(UInt8Array),
348 UInt16(UInt16Array),
350 UInt32(UInt32Array),
352 UInt64(UInt64Array),
354 Float32(Float32Array),
356 Float64(Float64Array),
358 String(StringArray),
360 Binary(BinaryArray),
362 FixedSizeBinary(FixedSizeBinaryArray),
364 List(Option<NullBuffer>, OffsetBuffer<i32>, Box<ArrayOrd>),
366 Struct(Option<NullBuffer>, Vec<ArrayOrd>),
368}
369
370impl ArrayOrd {
371 pub fn new(array: &dyn Array) -> Self {
373 match array.data_type() {
374 DataType::Null => ArrayOrd::Null(NullArray::from(array.to_data())),
375 DataType::Boolean => ArrayOrd::Bool(array.as_boolean().clone()),
376 DataType::Int8 => ArrayOrd::Int8(array.as_primitive().clone()),
377 DataType::Int16 => ArrayOrd::Int16(array.as_primitive().clone()),
378 DataType::Int32 => ArrayOrd::Int32(array.as_primitive().clone()),
379 DataType::Int64 => ArrayOrd::Int64(array.as_primitive().clone()),
380 DataType::UInt8 => ArrayOrd::UInt8(array.as_primitive().clone()),
381 DataType::UInt16 => ArrayOrd::UInt16(array.as_primitive().clone()),
382 DataType::UInt32 => ArrayOrd::UInt32(array.as_primitive().clone()),
383 DataType::UInt64 => ArrayOrd::UInt64(array.as_primitive().clone()),
384 DataType::Float32 => ArrayOrd::Float32(array.as_primitive().clone()),
385 DataType::Float64 => ArrayOrd::Float64(array.as_primitive().clone()),
386 DataType::Binary => ArrayOrd::Binary(array.as_binary().clone()),
387 DataType::Utf8 => ArrayOrd::String(array.as_string().clone()),
388 DataType::FixedSizeBinary(_) => {
389 ArrayOrd::FixedSizeBinary(array.as_fixed_size_binary().clone())
390 }
391 DataType::List(_) => {
392 let list_array = array.as_list();
393 ArrayOrd::List(
394 list_array.nulls().cloned(),
395 list_array.offsets().clone(),
396 Box::new(ArrayOrd::new(list_array.values())),
397 )
398 }
399 DataType::Struct(_) => {
400 let struct_array = array.as_struct();
401 let nulls = array.nulls().cloned();
402 let columns: Vec<_> = struct_array
403 .columns()
404 .iter()
405 .map(|a| ArrayOrd::new(a))
406 .collect();
407 ArrayOrd::Struct(nulls, columns)
408 }
409 data_type => unimplemented!("array type {data_type:?} not yet supported"),
410 }
411 }
412
413 pub fn goodbytes(&self) -> usize {
416 match self {
417 ArrayOrd::Null(_) => 0,
418 ArrayOrd::Bool(b) => b.len(),
421 ArrayOrd::Int8(a) => a.values().inner().len(),
422 ArrayOrd::Int16(a) => a.values().inner().len(),
423 ArrayOrd::Int32(a) => a.values().inner().len(),
424 ArrayOrd::Int64(a) => a.values().inner().len(),
425 ArrayOrd::UInt8(a) => a.values().inner().len(),
426 ArrayOrd::UInt16(a) => a.values().inner().len(),
427 ArrayOrd::UInt32(a) => a.values().inner().len(),
428 ArrayOrd::UInt64(a) => a.values().inner().len(),
429 ArrayOrd::Float32(a) => a.values().inner().len(),
430 ArrayOrd::Float64(a) => a.values().inner().len(),
431 ArrayOrd::String(a) => a.values().len(),
432 ArrayOrd::Binary(a) => a.values().len(),
433 ArrayOrd::FixedSizeBinary(a) => a.values().len(),
434 ArrayOrd::List(_, _, nested) => nested.goodbytes(),
435 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.goodbytes()).sum(),
436 }
437 }
438
439 pub fn at(&self, idx: usize) -> ArrayIdx<'_> {
441 ArrayIdx { idx, array: self }
442 }
443}
444
445impl Debug for ArrayOrd {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 struct DebugType<'a>(&'a ArrayOrd);
448
449 impl Debug for DebugType<'_> {
450 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451 match self.0 {
452 ArrayOrd::Null(_) => write!(f, "Null"),
453 ArrayOrd::Bool(_) => write!(f, "Bool"),
454 ArrayOrd::Int8(_) => write!(f, "Int8"),
455 ArrayOrd::Int16(_) => write!(f, "Int16"),
456 ArrayOrd::Int32(_) => write!(f, "Int32"),
457 ArrayOrd::Int64(_) => write!(f, "Int64"),
458 ArrayOrd::UInt8(_) => write!(f, "UInt8"),
459 ArrayOrd::UInt16(_) => write!(f, "UInt16"),
460 ArrayOrd::UInt32(_) => write!(f, "UInt32"),
461 ArrayOrd::UInt64(_) => write!(f, "UInt64"),
462 ArrayOrd::Float32(_) => write!(f, "Float32"),
463 ArrayOrd::Float64(_) => write!(f, "Float64"),
464 ArrayOrd::String(_) => write!(f, "String"),
465 ArrayOrd::Binary(_) => write!(f, "Binary"),
466 ArrayOrd::FixedSizeBinary(a) => f
467 .debug_tuple("FixedSizeBinary")
468 .field(&a.value_length())
469 .finish(),
470 ArrayOrd::List(_, _, nested) => f.debug_tuple("List").field(&*nested).finish(),
471 ArrayOrd::Struct(_, fields) => {
472 let mut tuple = f.debug_tuple("Struct");
473 for field in fields {
474 tuple.field(field);
475 }
476 tuple.finish()
477 }
478 }
479 }
480 }
481
482 f.debug_struct("ArrayOrd")
483 .field("type", &DebugType(self))
484 .field("goodbytes", &self.goodbytes())
485 .finish()
486 }
487}
488
489#[derive(Clone, Copy, Debug)]
497pub struct ArrayIdx<'a> {
498 pub idx: usize,
500 pub array: &'a ArrayOrd,
502}
503
504impl Display for ArrayIdx<'_> {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 match self.array {
507 ArrayOrd::Null(_) => write!(f, "null"),
508 ArrayOrd::Bool(a) => write!(f, "{}", a.value(self.idx)),
509 ArrayOrd::Int8(a) => write!(f, "{}", a.value(self.idx)),
510 ArrayOrd::Int16(a) => write!(f, "{}", a.value(self.idx)),
511 ArrayOrd::Int32(a) => write!(f, "{}", a.value(self.idx)),
512 ArrayOrd::Int64(a) => write!(f, "{}", a.value(self.idx)),
513 ArrayOrd::UInt8(a) => write!(f, "{}", a.value(self.idx)),
514 ArrayOrd::UInt16(a) => write!(f, "{}", a.value(self.idx)),
515 ArrayOrd::UInt32(a) => write!(f, "{}", a.value(self.idx)),
516 ArrayOrd::UInt64(a) => write!(f, "{}", a.value(self.idx)),
517 ArrayOrd::Float32(a) => write!(f, "{}", a.value(self.idx)),
518 ArrayOrd::Float64(a) => write!(f, "{}", a.value(self.idx)),
519 ArrayOrd::String(a) => write!(f, "{}", a.value(self.idx)),
520 ArrayOrd::Binary(a) => {
521 for byte in a.value(self.idx) {
522 write!(f, "{:02x}", byte)?;
523 }
524 Ok(())
525 }
526 ArrayOrd::FixedSizeBinary(a) => {
527 for byte in a.value(self.idx) {
528 write!(f, "{:02x}", byte)?;
529 }
530 Ok(())
531 }
532 ArrayOrd::List(_, offsets, nested) => {
533 write!(
534 f,
535 "[{}]",
536 mz_ore::str::separated(", ", list_range(offsets, nested, self.idx))
537 )
538 }
539 ArrayOrd::Struct(_, nested) => write!(
540 f,
541 "{{{}}}",
542 mz_ore::str::separated(", ", nested.iter().map(|f| f.at(self.idx)))
543 ),
544 }
545 }
546}
547
548#[inline]
549fn list_range<'a>(
550 offsets: &OffsetBuffer<i32>,
551 values: &'a ArrayOrd,
552 idx: usize,
553) -> impl Iterator<Item = ArrayIdx<'a>> + Clone {
554 let offsets = offsets.inner();
555 let from = offsets[idx].as_usize();
556 let to = offsets[idx + 1].as_usize();
557 (from..to).map(|i| values.at(i))
558}
559
560impl<'a> ArrayIdx<'a> {
561 pub fn goodbytes(&self) -> usize {
564 match self.array {
565 ArrayOrd::Null(_) => 0,
566 ArrayOrd::Bool(_) => size_of::<bool>(),
567 ArrayOrd::Int8(_) => size_of::<i8>(),
568 ArrayOrd::Int16(_) => size_of::<i16>(),
569 ArrayOrd::Int32(_) => size_of::<i32>(),
570 ArrayOrd::Int64(_) => size_of::<i64>(),
571 ArrayOrd::UInt8(_) => size_of::<u8>(),
572 ArrayOrd::UInt16(_) => size_of::<u16>(),
573 ArrayOrd::UInt32(_) => size_of::<u32>(),
574 ArrayOrd::UInt64(_) => size_of::<u64>(),
575 ArrayOrd::Float32(_) => size_of::<f32>(),
576 ArrayOrd::Float64(_) => size_of::<f64>(),
577 ArrayOrd::String(a) => a.value(self.idx).len(),
578 ArrayOrd::Binary(a) => a.value(self.idx).len(),
579 ArrayOrd::FixedSizeBinary(a) => a.value_length().as_usize(),
580 ArrayOrd::List(_, offsets, nested) => {
581 list_range(offsets, nested, self.idx)
583 .map(|a| a.goodbytes())
584 .sum()
585 }
586 ArrayOrd::Struct(_, nested) => nested.iter().map(|a| a.at(self.idx).goodbytes()).sum(),
587 }
588 }
589}
590
591impl<'a> Ord for ArrayIdx<'a> {
592 fn cmp(&self, other: &Self) -> Ordering {
593 #[inline]
594 fn is_null(buffer: &Option<NullBuffer>, idx: usize) -> bool {
595 buffer.as_ref().map_or(false, |b| b.is_null(idx))
596 }
597 #[inline]
598 fn cmp<A: ArrayAccessor>(
599 left: A,
600 left_idx: usize,
601 right: A,
602 right_idx: usize,
603 cmp: fn(&A::Item, &A::Item) -> Ordering,
604 ) -> Ordering {
605 match (left.is_null(left_idx), right.is_null(right_idx)) {
607 (false, true) => Ordering::Less,
608 (true, true) => Ordering::Equal,
609 (true, false) => Ordering::Greater,
610 (false, false) => cmp(&left.value(left_idx), &right.value(right_idx)),
611 }
612 }
613 match (&self.array, &other.array) {
614 (ArrayOrd::Null(s), ArrayOrd::Null(o)) => {
615 debug_assert!(
616 self.idx < s.len() && other.idx < o.len(),
617 "null array indices in bounds"
618 );
619 Ordering::Equal
620 }
621 (ArrayOrd::Bool(s), ArrayOrd::Bool(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
623 (ArrayOrd::Int8(s), ArrayOrd::Int8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
624 (ArrayOrd::Int16(s), ArrayOrd::Int16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
625 (ArrayOrd::Int32(s), ArrayOrd::Int32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
626 (ArrayOrd::Int64(s), ArrayOrd::Int64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
627 (ArrayOrd::UInt8(s), ArrayOrd::UInt8(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
628 (ArrayOrd::UInt16(s), ArrayOrd::UInt16(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
629 (ArrayOrd::UInt32(s), ArrayOrd::UInt32(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
630 (ArrayOrd::UInt64(s), ArrayOrd::UInt64(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
631 (ArrayOrd::Float32(s), ArrayOrd::Float32(o)) => {
632 cmp(s, self.idx, o, other.idx, f32::total_cmp)
633 }
634 (ArrayOrd::Float64(s), ArrayOrd::Float64(o)) => {
635 cmp(s, self.idx, o, other.idx, f64::total_cmp)
636 }
637 (ArrayOrd::String(s), ArrayOrd::String(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
638 (ArrayOrd::Binary(s), ArrayOrd::Binary(o)) => cmp(s, self.idx, o, other.idx, Ord::cmp),
639 (ArrayOrd::FixedSizeBinary(s), ArrayOrd::FixedSizeBinary(o)) => {
640 cmp(s, self.idx, o, other.idx, Ord::cmp)
641 }
642 (
645 ArrayOrd::List(s_nulls, s_offset, s_values),
646 ArrayOrd::List(o_nulls, o_offset, o_values),
647 ) => match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
648 (false, true) => Ordering::Less,
649 (true, true) => Ordering::Equal,
650 (true, false) => Ordering::Greater,
651 (false, false) => list_range(s_offset, s_values, self.idx)
652 .cmp(list_range(o_offset, o_values, other.idx)),
653 },
654 (ArrayOrd::Struct(s_nulls, s_cols), ArrayOrd::Struct(o_nulls, o_cols)) => {
657 match (is_null(s_nulls, self.idx), is_null(o_nulls, other.idx)) {
658 (false, true) => Ordering::Less,
659 (true, true) => Ordering::Equal,
660 (true, false) => Ordering::Greater,
661 (false, false) => {
662 let s = s_cols.iter().map(|array| array.at(self.idx));
663 let o = o_cols.iter().map(|array| array.at(other.idx));
664 s.cmp(o)
665 }
666 }
667 }
668 (a, b) => panic!("array types did not match! {a:?} vs. {b:?}",),
669 }
670 }
671}
672
673impl<'a> PartialOrd for ArrayIdx<'a> {
674 fn partial_cmp(&self, other: &ArrayIdx) -> Option<Ordering> {
675 Some(self.cmp(other))
676 }
677}
678
679impl<'a> PartialEq for ArrayIdx<'a> {
680 fn eq(&self, other: &ArrayIdx) -> bool {
681 self.cmp(other) == Ordering::Equal
682 }
683}
684
685impl<'a> Eq for ArrayIdx<'a> {}
686
687#[derive(Debug, Clone)]
689pub struct ArrayBound {
690 raw: ArrayRef,
691 ord: ArrayOrd,
692 index: usize,
693}
694
695impl PartialEq for ArrayBound {
696 fn eq(&self, other: &Self) -> bool {
697 self.get().eq(&other.get())
698 }
699}
700
701impl Eq for ArrayBound {}
702
703impl ArrayBound {
704 pub fn new(array: ArrayRef, index: usize) -> Self {
706 Self {
707 ord: ArrayOrd::new(array.as_ref()),
708 raw: array,
709 index,
710 }
711 }
712
713 pub fn get(&self) -> ArrayIdx<'_> {
715 self.ord.at(self.index)
716 }
717
718 pub fn to_proto_lower(&self, max_len: usize) -> Option<ProtoArrayData> {
722 let indices = UInt64Array::from_value(u64::usize_as(self.index), 1);
725 let taken = arrow::compute::take(self.raw.as_ref(), &indices, None).ok()?;
726 let array_data = taken.into_data();
727
728 let mut proto = array_data.into_proto();
729 let original_len = proto.encoded_len();
730 if original_len <= max_len {
731 return Some(proto);
732 }
733
734 let mut data_type = proto.data_type.take()?;
735 maybe_trim_proto(&mut data_type, &mut proto, max_len);
736 proto.data_type = Some(data_type);
737
738 if cfg!(debug_assertions) {
739 let array: ArrayData = proto
740 .clone()
741 .into_rust()
742 .expect("trimmed array data can still be decoded");
743 assert_eq!(array.len(), 1);
744 let new_bound = Self::new(make_array(array), 0);
745 assert!(
746 new_bound.get() <= self.get(),
747 "trimmed bound should be comparable to and no larger than the original data"
748 )
749 }
750
751 if proto.encoded_len() <= max_len {
752 Some(proto)
753 } else {
754 None
755 }
756 }
757}
758
759fn maybe_trim_proto(data_type: &mut proto::DataType, body: &mut ProtoArrayData, max_len: usize) {
767 assert!(body.data_type.is_none(), "expected separate data type");
768 let encoded_len = data_type.encoded_len() + body.encoded_len();
770 match &mut data_type.kind {
771 Some(data_type::Kind::Struct(data_type::Struct { children: fields })) => {
772 let mut struct_len = encoded_len;
774 while struct_len > max_len {
775 let Some(mut child) = body.children.pop() else {
776 break;
777 };
778 let Some(mut field) = fields.pop() else { break };
779
780 struct_len -= field.encoded_len() + child.encoded_len();
781 if let Some(remaining_len) = max_len.checked_sub(struct_len) {
782 let Some(field_type) = field.data_type.as_mut() else {
785 break;
786 };
787 maybe_trim_proto(field_type, &mut child, remaining_len);
788 if field.encoded_len() + child.encoded_len() <= remaining_len {
789 fields.push(field);
790 body.children.push(child);
791 }
792 break;
793 }
794 }
795 }
796 _ => {}
797 };
798}
799
#[cfg(test)]
mod tests {
    use crate::arrow::{ArrayBound, ArrayOrd};
    use arrow::array::{
        ArrayRef, AsArray, BooleanArray, StringArray, StructArray, UInt64Array, make_array,
    };
    use arrow::datatypes::{DataType, Field, Fields};
    use mz_ore::assert_none;
    use mz_proto::ProtoType;
    use std::sync::Arc;

    /// Encoding a bound under a length budget drops trailing struct columns until it fits.
    #[mz_ore::test]
    fn trim_proto() {
        let nested_fields: Fields = vec![Field::new("a", DataType::UInt64, true)].into();
        let array: ArrayRef = Arc::new(StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
                Field::new_struct("c", nested_fields.clone(), true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["large".repeat(50)])),
                Arc::new(StructArray::new_null(nested_fields, 1)),
            ],
            None,
        ));
        let bound = ArrayBound::new(array, 0);

        // Budgets this small cannot hold any encoding at all.
        assert_none!(bound.to_proto_lower(0));
        assert_none!(bound.to_proto_lower(1));

        // The string column alone is 250 bytes, so only the leading integer column survives.
        let proto = bound
            .to_proto_lower(100)
            .expect("can fit something in less than 100 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a"],
            "only the first column should fit"
        );

        let proto = bound
            .to_proto_lower(1000)
            .expect("can fit everything in less than 1000 bytes");
        let array = make_array(proto.into_rust().expect("valid proto"));
        assert_eq!(
            array.as_struct().column_names().as_slice(),
            &["a", "b", "c"],
            "all columns should fit"
        )
    }

    /// A struct with fewer columns compares as a prefix of a struct with more columns.
    #[mz_ore::test]
    fn struct_ord() {
        let prefix = StructArray::new(
            vec![Field::new("a", DataType::UInt64, true)].into(),
            vec![Arc::new(UInt64Array::from_iter_values([1, 3, 5]))],
            None,
        );
        let full = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([2, 3, 4])),
                Arc::new(StringArray::from_iter_values(["a", "b", "c"])),
            ],
            None,
        );
        let prefix_ord = ArrayOrd::new(&prefix);
        let full_ord = ArrayOrd::new(&full);

        assert!(prefix_ord.at(0) < full_ord.at(0), "(1) < (2, 'a')");
        // Equal on the shared column, so the shorter struct sorts first.
        assert!(prefix_ord.at(1) < full_ord.at(1), "(3) < (3, 'b')");
        assert!(prefix_ord.at(2) > full_ord.at(2), "(5) > (4, 'c')");
    }

    /// Comparing structs whose columns have different types panics once the mismatched
    /// columns are reached.
    #[mz_ore::test]
    #[should_panic(expected = "array types did not match")]
    fn struct_ord_incompat() {
        let string = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Utf8, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(StringArray::from_iter_values(["a"])),
            ],
            None,
        );
        let boolean = StructArray::new(
            vec![
                Field::new("a", DataType::UInt64, true),
                Field::new("b", DataType::Boolean, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from_iter_values([1])),
                Arc::new(BooleanArray::from_iter([Some(true)])),
            ],
            None,
        );
        let string_ord = ArrayOrd::new(&string);
        let bool_ord = ArrayOrd::new(&boolean);

        // The first columns are equal, which forces a comparison of the second columns.
        assert!(string_ord.at(0) < bool_ord.at(0));
    }
}