1use crate::{new_empty_array, Array, ArrayRef, StructArray};
22use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef};
23use std::ops::Index;
24use std::sync::Arc;
25
26pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
30 fn schema(&self) -> SchemaRef;
35
36 #[deprecated(
38 since = "2.0.0",
39 note = "This method is deprecated in favour of `next` from the trait Iterator."
40 )]
41 fn next_batch(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
42 self.next().transpose()
43 }
44}
45
46impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
47 fn schema(&self) -> SchemaRef {
48 self.as_ref().schema()
49 }
50}
51
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Writes a single batch to the writer.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Consumes the writer, reporting success or failure.
    ///
    /// NOTE(review): presumably this flushes/finalizes whatever the writer
    /// targets; the exact semantics are up to each implementor.
    fn close(self) -> Result<(), ArrowError>;
}
60
/// Creates a reference-counted array from a type tag and literal values.
///
/// Supported forms:
/// - `create_array!(Null, size)` builds a `NullArray` of the given size.
/// - `create_array!(Binary, [..])` / `create_array!(LargeBinary, [..])` build
///   binary arrays via `from_vec`.
/// - `create_array!(Tag, [..])` for every other tag resolves the concrete array
///   type through the internal `@from` rules and builds it with `From<Vec<_>>`.
///
/// An unknown tag produces a `compile_error!` naming the offending tag.
#[macro_export]
macro_rules! create_array {
    // `@from` arms: internal lookup table from a type tag to the array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Fixed: previously expanded to the non-existent `Time64Nanosecond64Array`.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Fallback for any tag not listed above.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // Null arrays are described by a size rather than a list of values.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` instead of `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic case: resolve the array type via `@from` and convert the values.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
145
/// Builds a `RecordBatch` from a list of `(name, Type, [values...])` tuples.
///
/// Each tuple contributes one field (always declared nullable) typed as
/// `arrow_schema::DataType::<Type>`, and one column created through
/// `create_array!`. The macro evaluates to the `Result` returned by
/// `RecordBatch::try_new`.
///
/// Note that the expansion names `arrow_schema` directly, so that crate must
/// be resolvable at the call site.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {{
        $crate::RecordBatch::try_new(
            // Schema first, then the columns, in declaration order.
            std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ])),
            vec![$(
                $crate::create_array!($type, [$($values),*]),
            )*]
        )
    }};
}
183
/// A batch of column-oriented data: a schema plus one array per schema field.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    /// Schema describing the columns of this batch.
    schema: SchemaRef,
    /// Column arrays, positionally matched against the schema's fields.
    columns: Vec<Arc<dyn Array>>,

    /// Number of rows in this batch.
    ///
    /// Stored explicitly rather than derived from the columns, so that a batch
    /// with zero columns can still carry a row count.
    row_count: usize,
}
217
218impl RecordBatch {
219 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
247 let options = RecordBatchOptions::new();
248 Self::try_new_impl(schema, columns, &options)
249 }
250
    /// Creates a `RecordBatch` from a schema and columns, with validation
    /// controlled by the supplied `options` (explicit row count and whether
    /// type comparison is strict).
    ///
    /// # Errors
    ///
    /// Returns the error produced by `try_new_impl` when validation fails.
    pub fn try_new_with_options(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        Self::try_new_impl(schema, columns, options)
    }
262
263 pub fn new_empty(schema: SchemaRef) -> Self {
265 let columns = schema
266 .fields()
267 .iter()
268 .map(|field| new_empty_array(field.data_type()))
269 .collect();
270
271 RecordBatch {
272 schema,
273 columns,
274 row_count: 0,
275 }
276 }
277
278 fn try_new_impl(
281 schema: SchemaRef,
282 columns: Vec<ArrayRef>,
283 options: &RecordBatchOptions,
284 ) -> Result<Self, ArrowError> {
285 if schema.fields().len() != columns.len() {
287 return Err(ArrowError::InvalidArgumentError(format!(
288 "number of columns({}) must match number of fields({}) in schema",
289 columns.len(),
290 schema.fields().len(),
291 )));
292 }
293
294 let row_count = options
295 .row_count
296 .or_else(|| columns.first().map(|col| col.len()))
297 .ok_or_else(|| {
298 ArrowError::InvalidArgumentError(
299 "must either specify a row count or at least one column".to_string(),
300 )
301 })?;
302
303 for (c, f) in columns.iter().zip(&schema.fields) {
304 if !f.is_nullable() && c.null_count() > 0 {
305 return Err(ArrowError::InvalidArgumentError(format!(
306 "Column '{}' is declared as non-nullable but contains null values",
307 f.name()
308 )));
309 }
310 }
311
312 if columns.iter().any(|c| c.len() != row_count) {
314 let err = match options.row_count {
315 Some(_) => "all columns in a record batch must have the specified row count",
316 None => "all columns in a record batch must have the same length",
317 };
318 return Err(ArrowError::InvalidArgumentError(err.to_string()));
319 }
320
321 let type_not_match = if options.match_field_names {
324 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
325 } else {
326 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
327 !col_type.equals_datatype(field_type)
328 }
329 };
330
331 let not_match = columns
333 .iter()
334 .zip(schema.fields().iter())
335 .map(|(col, field)| (col.data_type(), field.data_type()))
336 .enumerate()
337 .find(type_not_match);
338
339 if let Some((i, (col_type, field_type))) = not_match {
340 return Err(ArrowError::InvalidArgumentError(format!(
341 "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
342 }
343
344 Ok(RecordBatch {
345 schema,
346 columns,
347 row_count,
348 })
349 }
350
351 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
356 if !schema.contains(self.schema.as_ref()) {
357 return Err(ArrowError::SchemaError(format!(
358 "target schema is not superset of current schema target={schema} current={}",
359 self.schema
360 )));
361 }
362
363 Ok(Self {
364 schema,
365 columns: self.columns,
366 row_count: self.row_count,
367 })
368 }
369
370 pub fn schema(&self) -> SchemaRef {
372 self.schema.clone()
373 }
374
    /// Returns a reference to the schema of this batch, without cloning the
    /// shared handle.
    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }
379
380 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
382 let projected_schema = self.schema.project(indices)?;
383 let batch_fields = indices
384 .iter()
385 .map(|f| {
386 self.columns.get(*f).cloned().ok_or_else(|| {
387 ArrowError::SchemaError(format!(
388 "project index {} out of bounds, max field {}",
389 f,
390 self.columns.len()
391 ))
392 })
393 })
394 .collect::<Result<Vec<_>, _>>()?;
395
396 RecordBatch::try_new_with_options(
397 SchemaRef::new(projected_schema),
398 batch_fields,
399 &RecordBatchOptions {
400 match_field_names: true,
401 row_count: Some(self.row_count),
402 },
403 )
404 }
405
    /// Returns the number of columns in this batch.
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }
427
    /// Returns the number of rows in this batch (the stored row count, which
    /// is valid even when the batch has no columns).
    pub fn num_rows(&self) -> usize {
        self.row_count
    }
449
    /// Returns a reference to the column at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is outside the range `0..num_columns()`.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }
458
459 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
461 self.schema()
462 .column_with_name(name)
463 .map(|(index, _)| &self.columns[index])
464 }
465
466 pub fn columns(&self) -> &[ArrayRef] {
468 &self.columns[..]
469 }
470
    /// Removes the column at `index` from both the schema and the column list,
    /// returning the removed array.
    ///
    /// The schema is rebuilt through a `SchemaBuilder` seeded from the current
    /// schema, rather than from the bare field list.
    ///
    /// # Panics
    ///
    /// `Vec::remove` panics if `index` is out of bounds.
    /// NOTE(review): `SchemaBuilder::remove` runs first and presumably panics
    /// on the same condition, before any state is modified; confirm.
    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
        builder.remove(index);
        self.schema = Arc::new(builder.finish());
        self.columns.remove(index)
    }
504
505 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
512 assert!((offset + length) <= self.num_rows());
513
514 let columns = self
515 .columns()
516 .iter()
517 .map(|column| column.slice(offset, length))
518 .collect();
519
520 Self {
521 schema: self.schema.clone(),
522 columns,
523 row_count: length,
524 }
525 }
526
527 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
564 where
565 I: IntoIterator<Item = (F, ArrayRef)>,
566 F: AsRef<str>,
567 {
568 let iter = value.into_iter().map(|(field_name, array)| {
572 let nullable = array.null_count() > 0;
573 (field_name, array, nullable)
574 });
575
576 Self::try_from_iter_with_nullable(iter)
577 }
578
579 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
601 where
602 I: IntoIterator<Item = (F, ArrayRef, bool)>,
603 F: AsRef<str>,
604 {
605 let iter = value.into_iter();
606 let capacity = iter.size_hint().0;
607 let mut schema = SchemaBuilder::with_capacity(capacity);
608 let mut columns = Vec::with_capacity(capacity);
609
610 for (field_name, array, nullable) in iter {
611 let field_name = field_name.as_ref();
612 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
613 columns.push(array);
614 }
615
616 let schema = Arc::new(schema.finish());
617 RecordBatch::try_new(schema, columns)
618 }
619
620 pub fn get_array_memory_size(&self) -> usize {
627 self.columns()
628 .iter()
629 .map(|array| array.get_array_memory_size())
630 .sum()
631 }
632}
633
/// Options that control the behaviour used when creating a `RecordBatch`.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true`, column and field data types are compared with `==`;
    /// when `false`, they are compared with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Optional explicit row count; when `None`, the length of the first
    /// column is used instead.
    pub row_count: Option<usize>,
}
644
645impl RecordBatchOptions {
646 pub fn new() -> Self {
648 Self {
649 match_field_names: true,
650 row_count: None,
651 }
652 }
653 pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
655 self.row_count = row_count;
656 self
657 }
658 pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
660 self.match_field_names = match_field_names;
661 self
662 }
663}
/// The default options are exactly those produced by `RecordBatchOptions::new`.
impl Default for RecordBatchOptions {
    fn default() -> Self {
        Self::new()
    }
}
669impl From<StructArray> for RecordBatch {
670 fn from(value: StructArray) -> Self {
671 let row_count = value.len();
672 let (fields, columns, nulls) = value.into_parts();
673 assert_eq!(
674 nulls.map(|n| n.null_count()).unwrap_or_default(),
675 0,
676 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
677 );
678
679 RecordBatch {
680 schema: Arc::new(Schema::new(fields)),
681 row_count,
682 columns,
683 }
684 }
685}
686
687impl From<&StructArray> for RecordBatch {
688 fn from(struct_array: &StructArray) -> Self {
689 struct_array.clone().into()
690 }
691}
692
693impl Index<&str> for RecordBatch {
694 type Output = ArrayRef;
695
696 fn index(&self, name: &str) -> &Self::Output {
702 self.column_by_name(name).unwrap()
703 }
704}
705
/// Adapter that pairs any iterator of `Result<RecordBatch, ArrowError>` with a
/// schema, so that it can be used as a `RecordBatchReader`.
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// The wrapped iterator that actually produces the batches.
    inner: I::IntoIter,
    /// The schema reported by `RecordBatchReader::schema`.
    inner_schema: SchemaRef,
}
738
739impl<I> RecordBatchIterator<I>
740where
741 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
742{
743 pub fn new(iter: I, schema: SchemaRef) -> Self {
747 Self {
748 inner: iter.into_iter(),
749 inner_schema: schema,
750 }
751 }
752}
753
/// Iteration is delegated entirely to the wrapped iterator.
impl<I> Iterator for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    /// Forwards the wrapped iterator's size hint unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
768
769impl<I> RecordBatchReader for RecordBatchIterator<I>
770where
771 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
772{
773 fn schema(&self) -> SchemaRef {
774 self.inner_schema.clone()
775 }
776}
777
778#[cfg(test)]
779mod tests {
780 use std::collections::HashMap;
781
782 use super::*;
783 use crate::{
784 BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
785 };
786 use arrow_buffer::{Buffer, ToByteSlice};
787 use arrow_data::{ArrayData, ArrayDataBuilder};
788 use arrow_schema::Fields;
789
    /// A batch built from an Int32 and a Utf8 column of five rows passes the
    /// shared `check_batch` assertions.
    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        check_batch(record_batch, 5)
    }
804
    /// Same as `create_record_batch`, but with a Utf8View column; the shape and
    /// field types are asserted inline because `check_batch` expects Utf8.
    #[test]
    fn create_string_view_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8View, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        assert_eq!(5, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(
            &DataType::Utf8View,
            record_batch.schema().field(1).data_type()
        );
        assert_eq!(5, record_batch.column(0).len());
        assert_eq!(5, record_batch.column(1).len());
    }
828
    /// Pins the reported memory size of a small two-column batch so that
    /// changes to it are noticed.
    #[test]
    fn byte_size_should_not_regress() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        assert_eq!(record_batch.get_array_memory_size(), 364);
    }
843
    /// Shared assertions for the two-column (Int32, Utf8) fixture: checks the
    /// row count, the column count, both field types and both column lengths.
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).len());
        assert_eq!(num_rows, record_batch.column(1).len());
    }
852
    /// In-bounds slices (including a zero-length one) keep the schema and
    /// report the requested length; the final out-of-bounds slice must trigger
    /// the bounds assertion named in `should_panic`.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();

        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        // Regular slice inside the batch.
        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);

        // Zero-length slice.
        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);

        // 2 + 10 exceeds the 8 rows available: expected to panic.
        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }
886
    /// An empty batch can be sliced at `(0, 0)`, but any slice reaching past
    /// its zero rows must trigger the bounds assertion.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::empty();

        let record_batch = RecordBatch::new_empty(Arc::new(schema));

        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());

        // Out of bounds for a batch with zero rows: expected to panic.
        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }
903
    /// `try_from_iter` infers nullability per column: "a" contains a null and
    /// becomes nullable, "b" has none and stays non-nullable.
    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }
925
    /// `try_from_iter_with_nullable` uses the caller-supplied nullable flags
    /// verbatim, regardless of whether the arrays contain nulls.
    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
                .expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }
943
944 #[test]
945 fn create_record_batch_schema_mismatch() {
946 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
947
948 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
949
950 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
951 assert!(batch.is_err());
952 }
953
    /// A struct column whose nested field names differ from the schema
    /// ("aa1" vs "a1", list item "array" vs "item") is rejected under the
    /// default options but accepted once `match_field_names` is disabled.
    #[test]
    fn create_record_batch_field_name_mismatch() {
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new("item", DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        // List array whose item field is named "array" instead of "item".
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        // Struct array whose first child is named "aa1" instead of "a1".
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // Strict comparison (the default) rejects the differing names.
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // With `match_field_names` off, the same column is accepted.
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }
998
999 #[test]
1000 fn create_record_batch_record_mismatch() {
1001 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1002
1003 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1004 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1005
1006 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1007 assert!(batch.is_err());
1008 }
1009
    /// Converting a `&StructArray` yields a batch whose schema fields equal the
    /// struct's fields and whose columns equal the struct's child arrays.
    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().clone())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }
1035
    /// Two batches built independently from identical schemas and values
    /// compare equal.
    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1, batch2);
    }
1066
    /// Indexing a batch by field name returns the corresponding column.
    #[test]
    fn record_batch_index_access() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();

        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
    }
1082
    /// Batches with identical schemas but different values in the "val" column
    /// compare unequal.
    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1113
    /// Batches with identical values but a differing field name
    /// ("val" vs "num") compare unequal.
    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1144
    /// A two-column batch and a three-column batch compare unequal even though
    /// their first two columns match.
    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1177
1178 #[test]
1179 fn record_batch_row_count_ne() {
1180 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1181 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1182 let schema1 = Schema::new(vec![
1183 Field::new("id", DataType::Int32, false),
1184 Field::new("val", DataType::Int32, false),
1185 ]);
1186
1187 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1188 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1189 let schema2 = Schema::new(vec![
1190 Field::new("id", DataType::Int32, false),
1191 Field::new("num", DataType::Int32, false),
1192 ]);
1193
1194 let batch1 = RecordBatch::try_new(
1195 Arc::new(schema1),
1196 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1197 )
1198 .unwrap();
1199
1200 let batch2 = RecordBatch::try_new(
1201 Arc::new(schema2),
1202 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1203 )
1204 .unwrap();
1205
1206 assert_ne!(batch1, batch2);
1207 }
1208
    /// Projecting columns 0 and 2 of an (a, b, c) batch equals a batch built
    /// directly from (a, c).
    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
                .expect("valid conversion");

        let expected =
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
    }
1224
    /// Projecting onto no columns yields a column-less batch that still
    /// reports the original three rows.
    #[test]
    fn project_empty() {
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

        let expected = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions {
                match_field_names: true,
                row_count: Some(3),
            },
        )
        .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[]).unwrap());
    }
1244
    /// A batch without columns needs an explicit row count; once it has one,
    /// slicing adjusts that count, and a zero-row slice equals `new_empty`.
    #[test]
    fn test_no_column_record_batch() {
        let schema = Arc::new(Schema::empty());

        // No columns and no explicit row count: construction must fail.
        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
        assert!(err
            .to_string()
            .contains("must either specify a row count or at least one column"));

        let options = RecordBatchOptions::new().with_row_count(Some(10));

        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        assert_eq!(ok.num_rows(), 10);

        let a = ok.slice(2, 5);
        assert_eq!(a.num_rows(), 5);

        let b = ok.slice(5, 0);
        assert_eq!(b.num_rows(), 0);

        // Row count takes part in equality even when there are no columns.
        assert_ne!(a, b);
        assert_eq!(b, RecordBatch::new_empty(schema))
    }
1268
1269 #[test]
1270 fn test_nulls_in_non_nullable_field() {
1271 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1272 let maybe_batch = RecordBatch::try_new(
1273 schema,
1274 vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1275 );
1276 assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap()));
1277 }
1278 #[test]
1279 fn test_record_batch_options() {
1280 let options = RecordBatchOptions::new()
1281 .with_match_field_names(false)
1282 .with_row_count(Some(20));
1283 assert!(!options.match_field_names);
1284 assert_eq!(options.row_count.unwrap(), 20)
1285 }
1286
    /// Converting an all-null `StructArray` must panic with the
    /// "Cannot convert nullable StructArray" message.
    #[test]
    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
    fn test_from_struct() {
        let s = StructArray::from(ArrayData::new_null(
            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
            2,
        ));
        let _ = RecordBatch::from(s);
    }
1297
    /// `with_schema` accepts a schema that is a superset of the current one
    /// and rejects the reverse direction.
    #[test]
    fn test_with_schema() {
        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let required_schema = Arc::new(required_schema);
        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let nullable_schema = Arc::new(nullable_schema);

        let batch = RecordBatch::try_new(
            required_schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
        )
        .unwrap();

        // Non-nullable -> nullable is accepted.
        let batch = batch.with_schema(nullable_schema.clone()).unwrap();

        // Nullable -> non-nullable is rejected.
        batch.clone().with_schema(required_schema).unwrap_err();

        // Adding schema metadata is accepted.
        let metadata = vec![("foo".to_string(), "bar".to_string())]
            .into_iter()
            .collect();
        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

        // Dropping that metadata again is rejected.
        batch.with_schema(nullable_schema).unwrap_err();
    }
1327
    /// A `Box<dyn RecordBatchReader + Send>` can be passed where an
    /// `impl RecordBatchReader` is expected.
    #[test]
    fn test_boxed_reader() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let schema = Arc::new(schema);

        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        // Only compiles if the boxed trait object itself implements the trait.
        fn get_size(reader: impl RecordBatchReader) -> usize {
            reader.size_hint().0
        }

        let size = get_size(reader);
        assert_eq!(size, 0);
    }
1345
    /// Removing a column must leave the schema-level metadata untouched.
    #[test]
    fn test_remove_column_maintains_schema_metadata() {
        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);

        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool", DataType::Boolean, false),
        ])
        .with_metadata(metadata);

        let mut batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_array), Arc::new(bool_array)],
        )
        .unwrap();

        let _removed_column = batch.remove_column(0);
        assert_eq!(batch.schema().metadata().len(), 1);
        assert_eq!(
            batch.schema().metadata().get("foo").unwrap().as_str(),
            "bar"
        );
    }
1372}