arrow_select/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines concat kernel for `ArrayRef`
19//!
20//! Example:
21//!
22//! ```
23//! use arrow_array::{ArrayRef, StringArray};
24//! use arrow_select::concat::concat;
25//!
26//! let arr = concat(&[
27//!     &StringArray::from(vec!["hello", "world"]),
28//!     &StringArray::from(vec!["!"]),
29//! ]).unwrap();
30//! assert_eq!(arr.len(), 3);
31//! ```
32
33use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
34use arrow_array::cast::AsArray;
35use arrow_array::types::*;
36use arrow_array::*;
37use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
38use arrow_data::transform::{Capacities, MutableArrayData};
39use arrow_schema::{ArrowError, DataType, SchemaRef};
40use std::sync::Arc;
41
42fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
43    let mut item_capacity = 0;
44    let mut bytes_capacity = 0;
45    for array in arrays {
46        let a = array.as_bytes::<T>();
47
48        // Guaranteed to always have at least one element
49        let offsets = a.value_offsets();
50        bytes_capacity += offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize();
51        item_capacity += a.len()
52    }
53
54    Capacities::Binary(item_capacity, Some(bytes_capacity))
55}
56
57fn fixed_size_list_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
58    if let DataType::FixedSizeList(f, _) = data_type {
59        let item_capacity = arrays.iter().map(|a| a.len()).sum();
60        let child_data_type = f.data_type();
61        match child_data_type {
62            // These types should match the types that `get_capacity`
63            // has special handling for.
64            DataType::Utf8
65            | DataType::LargeUtf8
66            | DataType::Binary
67            | DataType::LargeBinary
68            | DataType::FixedSizeList(_, _) => {
69                let values: Vec<&dyn arrow_array::Array> = arrays
70                    .iter()
71                    .map(|a| a.as_fixed_size_list().values().as_ref())
72                    .collect();
73                Capacities::List(
74                    item_capacity,
75                    Some(Box::new(get_capacity(&values, child_data_type))),
76                )
77            }
78            _ => Capacities::Array(item_capacity),
79        }
80    } else {
81        unreachable!("illegal data type for fixed size list")
82    }
83}
84
85fn concat_dictionaries<K: ArrowDictionaryKeyType>(
86    arrays: &[&dyn Array],
87) -> Result<ArrayRef, ArrowError> {
88    let mut output_len = 0;
89    let dictionaries: Vec<_> = arrays
90        .iter()
91        .map(|x| x.as_dictionary::<K>())
92        .inspect(|d| output_len += d.len())
93        .collect();
94
95    if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
96        return concat_fallback(arrays, Capacities::Array(output_len));
97    }
98
99    let merged = merge_dictionary_values(&dictionaries, None)?;
100
101    // Recompute keys
102    let mut key_values = Vec::with_capacity(output_len);
103
104    let mut has_nulls = false;
105    for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
106        has_nulls |= d.null_count() != 0;
107        for key in d.keys().values() {
108            // Use get to safely handle nulls
109            key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default())
110        }
111    }
112
113    let nulls = has_nulls.then(|| {
114        let mut nulls = BooleanBufferBuilder::new(output_len);
115        for d in &dictionaries {
116            match d.nulls() {
117                Some(n) => nulls.append_buffer(n.inner()),
118                None => nulls.append_n(d.len(), true),
119            }
120        }
121        NullBuffer::new(nulls.finish())
122    });
123
124    let keys = PrimitiveArray::<K>::new(key_values.into(), nulls);
125    // Sanity check
126    assert_eq!(keys.len(), output_len);
127
128    let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };
129    Ok(Arc::new(array))
130}
131
/// Dispatch helper for `downcast_integer!`: returns from the enclosing function with
/// the result of `concat_dictionaries` instantiated for the key type `$t`.
///
/// `concat_dictionaries` already yields a `Result<ArrayRef, ArrowError>`, so the result
/// is returned as-is rather than being wrapped in a second `Arc`.
macro_rules! dict_helper {
    ($t:ty, $arrays:expr) => {
        return concat_dictionaries::<$t>($arrays)
    };
}
137
138fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
139    match data_type {
140        DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
141        DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
142        DataType::Binary => binary_capacity::<BinaryType>(arrays),
143        DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
144        DataType::FixedSizeList(_, _) => fixed_size_list_capacity(arrays, data_type),
145        _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
146    }
147}
148
149/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
150pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
151    if arrays.is_empty() {
152        return Err(ArrowError::ComputeError(
153            "concat requires input of at least one array".to_string(),
154        ));
155    } else if arrays.len() == 1 {
156        let array = arrays[0];
157        return Ok(array.slice(0, array.len()));
158    }
159
160    let d = arrays[0].data_type();
161    if arrays.iter().skip(1).any(|array| array.data_type() != d) {
162        return Err(ArrowError::InvalidArgumentError(
163            "It is not possible to concatenate arrays of different data types.".to_string(),
164        ));
165    }
166    if let DataType::Dictionary(k, _) = d {
167        downcast_integer! {
168            k.as_ref() => (dict_helper, arrays),
169            _ => unreachable!("illegal dictionary key type {k}")
170        };
171    } else {
172        let capacity = get_capacity(arrays, d);
173        concat_fallback(arrays, capacity)
174    }
175}
176
177/// Concatenates arrays using MutableArrayData
178///
179/// This will naively concatenate dictionaries
180fn concat_fallback(arrays: &[&dyn Array], capacity: Capacities) -> Result<ArrayRef, ArrowError> {
181    let array_data: Vec<_> = arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
182    let array_data = array_data.iter().collect();
183    let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);
184
185    for (i, a) in arrays.iter().enumerate() {
186        mutable.extend(i, 0, a.len())
187    }
188
189    Ok(make_array(mutable.freeze()))
190}
191
192/// Concatenates `batches` together into a single [`RecordBatch`].
193///
194/// The output batch has the specified `schemas`; The schema of the
195/// input are ignored.
196///
197/// Returns an error if the types of underlying arrays are different.
198pub fn concat_batches<'a>(
199    schema: &SchemaRef,
200    input_batches: impl IntoIterator<Item = &'a RecordBatch>,
201) -> Result<RecordBatch, ArrowError> {
202    // When schema is empty, sum the number of the rows of all batches
203    if schema.fields().is_empty() {
204        let num_rows: usize = input_batches.into_iter().map(RecordBatch::num_rows).sum();
205        let mut options = RecordBatchOptions::default();
206        options.row_count = Some(num_rows);
207        return RecordBatch::try_new_with_options(schema.clone(), vec![], &options);
208    }
209
210    let batches: Vec<&RecordBatch> = input_batches.into_iter().collect();
211    if batches.is_empty() {
212        return Ok(RecordBatch::new_empty(schema.clone()));
213    }
214    let field_num = schema.fields().len();
215    let mut arrays = Vec::with_capacity(field_num);
216    for i in 0..field_num {
217        let array = concat(
218            &batches
219                .iter()
220                .map(|batch| batch.column(i).as_ref())
221                .collect::<Vec<_>>(),
222        )?;
223        arrays.push(array);
224    }
225    RecordBatch::try_new(schema.clone(), arrays)
226}
227
228#[cfg(test)]
229mod tests {
230    use super::*;
231    use arrow_array::builder::StringDictionaryBuilder;
232    use arrow_schema::{Field, Schema};
233
234    #[test]
235    fn test_concat_empty_vec() {
236        let re = concat(&[]);
237        assert!(re.is_err());
238    }
239
240    #[test]
241    fn test_concat_batches_no_columns() {
242        // Test concat using empty schema / batches without columns
243        let schema = Arc::new(Schema::empty());
244
245        let mut options = RecordBatchOptions::default();
246        options.row_count = Some(100);
247        let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
248        // put in 2 batches of 100 rows each
249        let re = concat_batches(&schema, &[batch.clone(), batch]).unwrap();
250
251        assert_eq!(re.num_rows(), 200);
252    }
253
254    #[test]
255    fn test_concat_one_element_vec() {
256        let arr = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
257            Some(-1),
258            Some(2),
259            None,
260        ])) as ArrayRef;
261        let result = concat(&[arr.as_ref()]).unwrap();
262        assert_eq!(
263            &arr, &result,
264            "concatenating single element array gives back the same result"
265        );
266    }
267
268    #[test]
269    fn test_concat_incompatible_datatypes() {
270        let re = concat(&[
271            &PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(2), None]),
272            &StringArray::from(vec![Some("hello"), Some("bar"), Some("world")]),
273        ]);
274        assert!(re.is_err());
275    }
276
277    #[test]
278    fn test_concat_string_arrays() {
279        let arr = concat(&[
280            &StringArray::from(vec!["hello", "world"]),
281            &StringArray::from(vec!["2", "3", "4"]),
282            &StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]),
283        ])
284        .unwrap();
285
286        let expected_output = Arc::new(StringArray::from(vec![
287            Some("hello"),
288            Some("world"),
289            Some("2"),
290            Some("3"),
291            Some("4"),
292            Some("foo"),
293            Some("bar"),
294            None,
295            Some("baz"),
296        ])) as ArrayRef;
297
298        assert_eq!(&arr, &expected_output);
299    }
300
301    #[test]
302    fn test_concat_primitive_arrays() {
303        let arr = concat(&[
304            &PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(-1), Some(2), None, None]),
305            &PrimitiveArray::<Int64Type>::from(vec![Some(101), Some(102), Some(103), None]),
306            &PrimitiveArray::<Int64Type>::from(vec![Some(256), Some(512), Some(1024)]),
307        ])
308        .unwrap();
309
310        let expected_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
311            Some(-1),
312            Some(-1),
313            Some(2),
314            None,
315            None,
316            Some(101),
317            Some(102),
318            Some(103),
319            None,
320            Some(256),
321            Some(512),
322            Some(1024),
323        ])) as ArrayRef;
324
325        assert_eq!(&arr, &expected_output);
326    }
327
328    #[test]
329    fn test_concat_primitive_array_slices() {
330        let input_1 =
331            PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(-1), Some(2), None, None])
332                .slice(1, 3);
333
334        let input_2 =
335            PrimitiveArray::<Int64Type>::from(vec![Some(101), Some(102), Some(103), None])
336                .slice(1, 3);
337        let arr = concat(&[&input_1, &input_2]).unwrap();
338
339        let expected_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
340            Some(-1),
341            Some(2),
342            None,
343            Some(102),
344            Some(103),
345            None,
346        ])) as ArrayRef;
347
348        assert_eq!(&arr, &expected_output);
349    }
350
351    #[test]
352    fn test_concat_boolean_primitive_arrays() {
353        let arr = concat(&[
354            &BooleanArray::from(vec![
355                Some(true),
356                Some(true),
357                Some(false),
358                None,
359                None,
360                Some(false),
361            ]),
362            &BooleanArray::from(vec![None, Some(false), Some(true), Some(false)]),
363        ])
364        .unwrap();
365
366        let expected_output = Arc::new(BooleanArray::from(vec![
367            Some(true),
368            Some(true),
369            Some(false),
370            None,
371            None,
372            Some(false),
373            None,
374            Some(false),
375            Some(true),
376            Some(false),
377        ])) as ArrayRef;
378
379        assert_eq!(&arr, &expected_output);
380    }
381
382    #[test]
383    fn test_concat_primitive_list_arrays() {
384        let list1 = vec![
385            Some(vec![Some(-1), Some(-1), Some(2), None, None]),
386            Some(vec![]),
387            None,
388            Some(vec![Some(10)]),
389        ];
390        let list1_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone());
391
392        let list2 = vec![
393            None,
394            Some(vec![Some(100), None, Some(101)]),
395            Some(vec![Some(102)]),
396        ];
397        let list2_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone());
398
399        let list3 = vec![Some(vec![Some(1000), Some(1001)])];
400        let list3_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list3.clone());
401
402        let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap();
403
404        let expected = list1.into_iter().chain(list2).chain(list3);
405        let array_expected = ListArray::from_iter_primitive::<Int64Type, _, _>(expected);
406
407        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
408    }
409
410    #[test]
411    fn test_concat_primitive_fixed_size_list_arrays() {
412        let list1 = vec![
413            Some(vec![Some(-1), None]),
414            None,
415            Some(vec![Some(10), Some(20)]),
416        ];
417        let list1_array =
418            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone(), 2);
419
420        let list2 = vec![
421            None,
422            Some(vec![Some(100), None]),
423            Some(vec![Some(102), Some(103)]),
424        ];
425        let list2_array =
426            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone(), 2);
427
428        let list3 = vec![Some(vec![Some(1000), Some(1001)])];
429        let list3_array =
430            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list3.clone(), 2);
431
432        let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap();
433
434        let expected = list1.into_iter().chain(list2).chain(list3);
435        let array_expected =
436            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(expected, 2);
437
438        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
439    }
440
441    #[test]
442    fn test_concat_struct_arrays() {
443        let field = Arc::new(Field::new("field", DataType::Int64, true));
444        let input_primitive_1: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
445            Some(-1),
446            Some(-1),
447            Some(2),
448            None,
449            None,
450        ]));
451        let input_struct_1 = StructArray::from(vec![(field.clone(), input_primitive_1)]);
452
453        let input_primitive_2: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
454            Some(101),
455            Some(102),
456            Some(103),
457            None,
458        ]));
459        let input_struct_2 = StructArray::from(vec![(field.clone(), input_primitive_2)]);
460
461        let input_primitive_3: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
462            Some(256),
463            Some(512),
464            Some(1024),
465        ]));
466        let input_struct_3 = StructArray::from(vec![(field, input_primitive_3)]);
467
468        let arr = concat(&[&input_struct_1, &input_struct_2, &input_struct_3]).unwrap();
469
470        let expected_primitive_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
471            Some(-1),
472            Some(-1),
473            Some(2),
474            None,
475            None,
476            Some(101),
477            Some(102),
478            Some(103),
479            None,
480            Some(256),
481            Some(512),
482            Some(1024),
483        ])) as ArrayRef;
484
485        let actual_primitive = arr
486            .as_any()
487            .downcast_ref::<StructArray>()
488            .unwrap()
489            .column(0);
490        assert_eq!(actual_primitive, &expected_primitive_output);
491    }
492
493    #[test]
494    fn test_concat_struct_array_slices() {
495        let field = Arc::new(Field::new("field", DataType::Int64, true));
496        let input_primitive_1: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
497            Some(-1),
498            Some(-1),
499            Some(2),
500            None,
501            None,
502        ]));
503        let input_struct_1 = StructArray::from(vec![(field.clone(), input_primitive_1)]);
504
505        let input_primitive_2: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
506            Some(101),
507            Some(102),
508            Some(103),
509            None,
510        ]));
511        let input_struct_2 = StructArray::from(vec![(field, input_primitive_2)]);
512
513        let arr = concat(&[&input_struct_1.slice(1, 3), &input_struct_2.slice(1, 2)]).unwrap();
514
515        let expected_primitive_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
516            Some(-1),
517            Some(2),
518            None,
519            Some(102),
520            Some(103),
521        ])) as ArrayRef;
522
523        let actual_primitive = arr
524            .as_any()
525            .downcast_ref::<StructArray>()
526            .unwrap()
527            .column(0);
528        assert_eq!(actual_primitive, &expected_primitive_output);
529    }
530
531    #[test]
532    fn test_string_array_slices() {
533        let input_1 = StringArray::from(vec!["hello", "A", "B", "C"]);
534        let input_2 = StringArray::from(vec!["world", "D", "E", "Z"]);
535
536        let arr = concat(&[&input_1.slice(1, 3), &input_2.slice(1, 2)]).unwrap();
537
538        let expected_output = StringArray::from(vec!["A", "B", "C", "D", "E"]);
539
540        let actual_output = arr.as_any().downcast_ref::<StringArray>().unwrap();
541        assert_eq!(actual_output, &expected_output);
542    }
543
544    #[test]
545    fn test_string_array_with_null_slices() {
546        let input_1 = StringArray::from(vec![Some("hello"), None, Some("A"), Some("C")]);
547        let input_2 = StringArray::from(vec![None, Some("world"), Some("D"), None]);
548
549        let arr = concat(&[&input_1.slice(1, 3), &input_2.slice(1, 2)]).unwrap();
550
551        let expected_output =
552            StringArray::from(vec![None, Some("A"), Some("C"), Some("world"), Some("D")]);
553
554        let actual_output = arr.as_any().downcast_ref::<StringArray>().unwrap();
555        assert_eq!(actual_output, &expected_output);
556    }
557
558    fn collect_string_dictionary(array: &DictionaryArray<Int32Type>) -> Vec<Option<&str>> {
559        let concrete = array.downcast_dict::<StringArray>().unwrap();
560        concrete.into_iter().collect()
561    }
562
563    #[test]
564    fn test_string_dictionary_array() {
565        let input_1: DictionaryArray<Int32Type> = vec!["hello", "A", "B", "hello", "hello", "C"]
566            .into_iter()
567            .collect();
568        let input_2: DictionaryArray<Int32Type> = vec!["hello", "E", "E", "hello", "F", "E"]
569            .into_iter()
570            .collect();
571
572        let expected: Vec<_> = vec![
573            "hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F", "E",
574        ]
575        .into_iter()
576        .map(Some)
577        .collect();
578
579        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
580        let dictionary = concat.as_dictionary::<Int32Type>();
581        let actual = collect_string_dictionary(dictionary);
582        assert_eq!(actual, expected);
583
584        // Should have concatenated inputs together
585        assert_eq!(
586            dictionary.values().len(),
587            input_1.values().len() + input_2.values().len(),
588        )
589    }
590
591    #[test]
592    fn test_string_dictionary_array_nulls() {
593        let input_1: DictionaryArray<Int32Type> = vec![Some("foo"), Some("bar"), None, Some("fiz")]
594            .into_iter()
595            .collect();
596        let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
597        let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];
598
599        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
600        let dictionary = concat.as_dictionary::<Int32Type>();
601        let actual = collect_string_dictionary(dictionary);
602        assert_eq!(actual, expected);
603
604        // Should have concatenated inputs together
605        assert_eq!(
606            dictionary.values().len(),
607            input_1.values().len() + input_2.values().len(),
608        )
609    }
610
611    #[test]
612    fn test_string_dictionary_merge() {
613        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
614        for i in 0..20 {
615            builder.append(i.to_string()).unwrap();
616        }
617        let input_1 = builder.finish();
618
619        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
620        for i in 0..30 {
621            builder.append(i.to_string()).unwrap();
622        }
623        let input_2 = builder.finish();
624
625        let expected: Vec<_> = (0..20).chain(0..30).map(|x| x.to_string()).collect();
626        let expected: Vec<_> = expected.iter().map(|x| Some(x.as_str())).collect();
627
628        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
629        let dictionary = concat.as_dictionary::<Int32Type>();
630        let actual = collect_string_dictionary(dictionary);
631        assert_eq!(actual, expected);
632
633        // Should have merged inputs together
634        // Not 30 as this is done on a best-effort basis
635        let values_len = dictionary.values().len();
636        assert!((30..40).contains(&values_len), "{values_len}")
637    }
638
639    #[test]
640    fn test_concat_string_sizes() {
641        let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
642        let b: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
643        let c = LargeStringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]);
644        // 150 * 3 = 450
645        // 150 * 3 = 450
646        // 3 * 3   = 9
647        // ------------+
648        // 909
649        // closest 64 byte aligned cap = 960
650
651        let arr = concat(&[&a, &b, &c]).unwrap();
652        // this would have been 1280 if we did not precompute the value lengths.
653        assert_eq!(arr.to_data().buffers()[1].capacity(), 960);
654    }
655
656    #[test]
657    fn test_dictionary_concat_reuse() {
658        let array: DictionaryArray<Int8Type> = vec!["a", "a", "b", "c"].into_iter().collect();
659        let copy: DictionaryArray<Int8Type> = array.clone();
660
661        // dictionary is "a", "b", "c"
662        assert_eq!(
663            array.values(),
664            &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef)
665        );
666        assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2]));
667
668        // concatenate it with itself
669        let combined = concat(&[&copy as _, &array as _]).unwrap();
670        let combined = combined.as_dictionary::<Int8Type>();
671
672        assert_eq!(
673            combined.values(),
674            &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef),
675            "Actual: {combined:#?}"
676        );
677
678        assert_eq!(
679            combined.keys(),
680            &Int8Array::from(vec![0, 0, 1, 2, 0, 0, 1, 2])
681        );
682
683        // Should have reused the dictionary
684        assert!(array
685            .values()
686            .to_data()
687            .ptr_eq(&combined.values().to_data()));
688        assert!(copy.values().to_data().ptr_eq(&combined.values().to_data()));
689
690        let new: DictionaryArray<Int8Type> = vec!["d"].into_iter().collect();
691        let combined = concat(&[&copy as _, &array as _, &new as _]).unwrap();
692        let com = combined.as_dictionary::<Int8Type>();
693
694        // Should not have reused the dictionary
695        assert!(!array.values().to_data().ptr_eq(&com.values().to_data()));
696        assert!(!copy.values().to_data().ptr_eq(&com.values().to_data()));
697        assert!(!new.values().to_data().ptr_eq(&com.values().to_data()));
698    }
699
700    #[test]
701    fn concat_record_batches() {
702        let schema = Arc::new(Schema::new(vec![
703            Field::new("a", DataType::Int32, false),
704            Field::new("b", DataType::Utf8, false),
705        ]));
706        let batch1 = RecordBatch::try_new(
707            schema.clone(),
708            vec![
709                Arc::new(Int32Array::from(vec![1, 2])),
710                Arc::new(StringArray::from(vec!["a", "b"])),
711            ],
712        )
713        .unwrap();
714        let batch2 = RecordBatch::try_new(
715            schema.clone(),
716            vec![
717                Arc::new(Int32Array::from(vec![3, 4])),
718                Arc::new(StringArray::from(vec!["c", "d"])),
719            ],
720        )
721        .unwrap();
722        let new_batch = concat_batches(&schema, [&batch1, &batch2]).unwrap();
723        assert_eq!(new_batch.schema().as_ref(), schema.as_ref());
724        assert_eq!(2, new_batch.num_columns());
725        assert_eq!(4, new_batch.num_rows());
726        let new_batch_owned = concat_batches(&schema, &[batch1, batch2]).unwrap();
727        assert_eq!(new_batch_owned.schema().as_ref(), schema.as_ref());
728        assert_eq!(2, new_batch_owned.num_columns());
729        assert_eq!(4, new_batch_owned.num_rows());
730    }
731
732    #[test]
733    fn concat_empty_record_batch() {
734        let schema = Arc::new(Schema::new(vec![
735            Field::new("a", DataType::Int32, false),
736            Field::new("b", DataType::Utf8, false),
737        ]));
738        let batch = concat_batches(&schema, []).unwrap();
739        assert_eq!(batch.schema().as_ref(), schema.as_ref());
740        assert_eq!(0, batch.num_rows());
741    }
742
743    #[test]
744    fn concat_record_batches_of_different_schemas_but_compatible_data() {
745        let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
746        // column names differ
747        let schema2 = Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, false)]));
748        let batch1 = RecordBatch::try_new(
749            schema1.clone(),
750            vec![Arc::new(Int32Array::from(vec![1, 2]))],
751        )
752        .unwrap();
753        let batch2 =
754            RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3, 4]))]).unwrap();
755        // concat_batches simply uses the schema provided
756        let batch = concat_batches(&schema1, [&batch1, &batch2]).unwrap();
757        assert_eq!(batch.schema().as_ref(), schema1.as_ref());
758        assert_eq!(4, batch.num_rows());
759    }
760
761    #[test]
762    fn concat_record_batches_of_different_schemas_incompatible_data() {
763        let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
764        // column names differ
765        let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
766        let batch1 = RecordBatch::try_new(
767            schema1.clone(),
768            vec![Arc::new(Int32Array::from(vec![1, 2]))],
769        )
770        .unwrap();
771        let batch2 = RecordBatch::try_new(
772            schema2,
773            vec![Arc::new(StringArray::from(vec!["foo", "bar"]))],
774        )
775        .unwrap();
776
777        let error = concat_batches(&schema1, [&batch1, &batch2]).unwrap_err();
778        assert_eq!(error.to_string(), "Invalid argument error: It is not possible to concatenate arrays of different data types.");
779    }
780
781    #[test]
782    fn concat_capacity() {
783        let a = Int32Array::from_iter_values(0..100);
784        let b = Int32Array::from_iter_values(10..20);
785        let a = concat(&[&a, &b]).unwrap();
786        let data = a.to_data();
787        assert_eq!(data.buffers()[0].len(), 440);
788        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
789
790        let a = concat(&[&a.slice(10, 20), &b]).unwrap();
791        let data = a.to_data();
792        assert_eq!(data.buffers()[0].len(), 120);
793        assert_eq!(data.buffers()[0].capacity(), 128); // Nearest multiple of 64
794
795        let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100));
796        let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]);
797
798        let a = concat(&[&a, &b]).unwrap();
799        let data = a.to_data();
800        // (100 + 4 + 1) * size_of<i32>()
801        assert_eq!(data.buffers()[0].len(), 420);
802        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
803
804        // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem")
805        assert_eq!(data.buffers()[1].len(), 315);
806        assert_eq!(data.buffers()[1].capacity(), 320); // Nearest multiple of 64
807
808        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
809        let data = a.to_data();
810        // (40 + 4 + 5) * size_of<i32>()
811        assert_eq!(data.buffers()[0].len(), 180);
812        assert_eq!(data.buffers()[0].capacity(), 192); // Nearest multiple of 64
813
814        // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem")
815        assert_eq!(data.buffers()[1].len(), 135);
816        assert_eq!(data.buffers()[1].capacity(), 192); // Nearest multiple of 64
817
818        let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100));
819        let b = LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10));
820
821        let a = concat(&[&a, &b]).unwrap();
822        let data = a.to_data();
823        // (100 + 10 + 1) * size_of<i64>()
824        assert_eq!(data.buffers()[0].len(), 888);
825        assert_eq!(data.buffers()[0].capacity(), 896); // Nearest multiple of 64
826
827        // len("foo") * 100 + len("cupcakes") * 10
828        assert_eq!(data.buffers()[1].len(), 380);
829        assert_eq!(data.buffers()[1].capacity(), 384); // Nearest multiple of 64
830
831        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
832        let data = a.to_data();
833        // (40 + 10 + 1) * size_of<i64>()
834        assert_eq!(data.buffers()[0].len(), 408);
835        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64
836
837        // len("foo") * 40 + len("cupcakes") * 10
838        assert_eq!(data.buffers()[1].len(), 200);
839        assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
840    }
841
842    #[test]
843    fn concat_sparse_nulls() {
844        let values = StringArray::from_iter_values((0..100).map(|x| x.to_string()));
845        let keys = Int32Array::from(vec![1; 10]);
846        let dict_a = DictionaryArray::new(keys, Arc::new(values));
847        let values = StringArray::new_null(0);
848        let keys = Int32Array::new_null(10);
849        let dict_b = DictionaryArray::new(keys, Arc::new(values));
850        let array = concat(&[&dict_a, &dict_b]).unwrap();
851        assert_eq!(array.null_count(), 10);
852        assert_eq!(array.logical_null_count(), 10);
853    }
854}