arrow_select/
interleave.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Interleave elements from multiple arrays
19
20use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
21use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder};
22use arrow_array::cast::AsArray;
23use arrow_array::types::*;
24use arrow_array::*;
25use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer, NullBufferBuilder, OffsetBuffer};
26use arrow_data::transform::MutableArrayData;
27use arrow_schema::{ArrowError, DataType};
28use std::sync::Arc;
29
// Callback handed to `downcast_primitive!` inside `interleave`: once the primitive
// type has been resolved, forward to `interleave_primitive` instantiated for it.
macro_rules! primitive_helper {
    ($prim:ty, $arrays:ident, $picks:ident, $dtype:ident) => {
        interleave_primitive::<$prim>($arrays, $picks, $dtype)
    };
}
35
// Callback handed to `downcast_integer!` inside `interleave`: once the dictionary
// key type has been resolved, forward to `interleave_dictionaries` instantiated for it.
//
// `interleave_dictionaries` already returns `Result<ArrayRef, ArrowError>`, so its
// result is passed through unchanged instead of being wrapped in a second `Arc`.
macro_rules! dict_helper {
    ($t:ty, $values:expr, $indices:expr) => {
        interleave_dictionaries::<$t>($values, $indices)
    };
}
41
42///
43/// Takes elements by index from a list of [`Array`], creating a new [`Array`] from those values.
44///
45/// Each element in `indices` is a pair of `usize` with the first identifying the index
46/// of the [`Array`] in `values`, and the second the index of the value within that [`Array`]
47///
48/// ```text
49/// ┌─────────────────┐      ┌─────────┐                                  ┌─────────────────┐
50/// │        A        │      │ (0, 0)  │        interleave(               │        A        │
51/// ├─────────────────┤      ├─────────┤          [values0, values1],     ├─────────────────┤
52/// │        D        │      │ (1, 0)  │          indices                 │        B        │
53/// └─────────────────┘      ├─────────┤        )                         ├─────────────────┤
54///   values array 0         │ (1, 1)  │      ─────────────────────────▶  │        C        │
55///                          ├─────────┤                                  ├─────────────────┤
56///                          │ (0, 1)  │                                  │        D        │
57///                          └─────────┘                                  └─────────────────┘
58/// ┌─────────────────┐       indices
59/// │        B        │        array
60/// ├─────────────────┤                                                    result
61/// │        C        │
62/// ├─────────────────┤
63/// │        E        │
64/// └─────────────────┘
65///   values array 1
66/// ```
67///
68/// For selecting values by index from a single array see [`crate::take`]
69pub fn interleave(
70    values: &[&dyn Array],
71    indices: &[(usize, usize)],
72) -> Result<ArrayRef, ArrowError> {
73    if values.is_empty() {
74        return Err(ArrowError::InvalidArgumentError(
75            "interleave requires input of at least one array".to_string(),
76        ));
77    }
78    let data_type = values[0].data_type();
79
80    for array in values.iter().skip(1) {
81        if array.data_type() != data_type {
82            return Err(ArrowError::InvalidArgumentError(format!(
83                "It is not possible to interleave arrays of different data types ({} and {})",
84                data_type,
85                array.data_type()
86            )));
87        }
88    }
89
90    if indices.is_empty() {
91        return Ok(new_empty_array(data_type));
92    }
93
94    downcast_primitive! {
95        data_type => (primitive_helper, values, indices, data_type),
96        DataType::Utf8 => interleave_bytes::<Utf8Type>(values, indices),
97        DataType::LargeUtf8 => interleave_bytes::<LargeUtf8Type>(values, indices),
98        DataType::Binary => interleave_bytes::<BinaryType>(values, indices),
99        DataType::LargeBinary => interleave_bytes::<LargeBinaryType>(values, indices),
100        DataType::Dictionary(k, _) => downcast_integer! {
101            k.as_ref() => (dict_helper, values, indices),
102            _ => unreachable!("illegal dictionary key type {k}")
103        },
104        _ => interleave_fallback(values, indices)
105    }
106}
107
108/// Common functionality for interleaving arrays
109///
110/// T is the concrete Array type
111struct Interleave<'a, T> {
112    /// The input arrays downcast to T
113    arrays: Vec<&'a T>,
114    /// The null buffer of the interleaved output
115    nulls: Option<NullBuffer>,
116}
117
118impl<'a, T: Array + 'static> Interleave<'a, T> {
119    fn new(values: &[&'a dyn Array], indices: &'a [(usize, usize)]) -> Self {
120        let mut has_nulls = false;
121        let arrays: Vec<&T> = values
122            .iter()
123            .map(|x| {
124                has_nulls = has_nulls || x.null_count() != 0;
125                x.as_any().downcast_ref().unwrap()
126            })
127            .collect();
128
129        let nulls = match has_nulls {
130            true => {
131                let mut builder = NullBufferBuilder::new(indices.len());
132                for (a, b) in indices {
133                    let v = arrays[*a].is_valid(*b);
134                    builder.append(v)
135                }
136                builder.finish()
137            }
138            false => None,
139        };
140
141        Self { arrays, nulls }
142    }
143}
144
145fn interleave_primitive<T: ArrowPrimitiveType>(
146    values: &[&dyn Array],
147    indices: &[(usize, usize)],
148    data_type: &DataType,
149) -> Result<ArrayRef, ArrowError> {
150    let interleaved = Interleave::<'_, PrimitiveArray<T>>::new(values, indices);
151
152    let mut values = Vec::with_capacity(indices.len());
153    for (a, b) in indices {
154        let v = interleaved.arrays[*a].value(*b);
155        values.push(v)
156    }
157
158    let array = PrimitiveArray::<T>::new(values.into(), interleaved.nulls);
159    Ok(Arc::new(array.with_data_type(data_type.clone())))
160}
161
162fn interleave_bytes<T: ByteArrayType>(
163    values: &[&dyn Array],
164    indices: &[(usize, usize)],
165) -> Result<ArrayRef, ArrowError> {
166    let interleaved = Interleave::<'_, GenericByteArray<T>>::new(values, indices);
167
168    let mut capacity = 0;
169    let mut offsets = BufferBuilder::<T::Offset>::new(indices.len() + 1);
170    offsets.append(T::Offset::from_usize(0).unwrap());
171    for (a, b) in indices {
172        let o = interleaved.arrays[*a].value_offsets();
173        let element_len = o[*b + 1].as_usize() - o[*b].as_usize();
174        capacity += element_len;
175        offsets.append(T::Offset::from_usize(capacity).expect("overflow"));
176    }
177
178    let mut values = MutableBuffer::new(capacity);
179    for (a, b) in indices {
180        values.extend_from_slice(interleaved.arrays[*a].value(*b).as_ref());
181    }
182
183    // Safety: safe by construction
184    let array = unsafe {
185        let offsets = OffsetBuffer::new_unchecked(offsets.finish().into());
186        GenericByteArray::<T>::new_unchecked(offsets, values.into(), interleaved.nulls)
187    };
188    Ok(Arc::new(array))
189}
190
191fn interleave_dictionaries<K: ArrowDictionaryKeyType>(
192    arrays: &[&dyn Array],
193    indices: &[(usize, usize)],
194) -> Result<ArrayRef, ArrowError> {
195    let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::<K>()).collect();
196    if !should_merge_dictionary_values::<K>(&dictionaries, indices.len()) {
197        return interleave_fallback(arrays, indices);
198    }
199
200    let masks: Vec<_> = dictionaries
201        .iter()
202        .enumerate()
203        .map(|(a_idx, dictionary)| {
204            let mut key_mask = BooleanBufferBuilder::new_from_buffer(
205                MutableBuffer::new_null(dictionary.len()),
206                dictionary.len(),
207            );
208
209            for (_, key_idx) in indices.iter().filter(|(a, _)| *a == a_idx) {
210                key_mask.set_bit(*key_idx, true);
211            }
212            key_mask.finish()
213        })
214        .collect();
215
216    let merged = merge_dictionary_values(&dictionaries, Some(&masks))?;
217
218    // Recompute keys
219    let mut keys = PrimitiveBuilder::<K>::with_capacity(indices.len());
220    for (a, b) in indices {
221        let old_keys: &PrimitiveArray<K> = dictionaries[*a].keys();
222        match old_keys.is_valid(*b) {
223            true => {
224                let old_key = old_keys.values()[*b];
225                keys.append_value(merged.key_mappings[*a][old_key.as_usize()])
226            }
227            false => keys.append_null(),
228        }
229    }
230    let array = unsafe { DictionaryArray::new_unchecked(keys.finish(), merged.values) };
231    Ok(Arc::new(array))
232}
233
234/// Fallback implementation of interleave using [`MutableArrayData`]
235fn interleave_fallback(
236    values: &[&dyn Array],
237    indices: &[(usize, usize)],
238) -> Result<ArrayRef, ArrowError> {
239    let arrays: Vec<_> = values.iter().map(|x| x.to_data()).collect();
240    let arrays: Vec<_> = arrays.iter().collect();
241    let mut array_data = MutableArrayData::new(arrays, false, indices.len());
242
243    let mut cur_array = indices[0].0;
244    let mut start_row_idx = indices[0].1;
245    let mut end_row_idx = start_row_idx + 1;
246
247    for (array, row) in indices.iter().skip(1).copied() {
248        if array == cur_array && row == end_row_idx {
249            // subsequent row in same batch
250            end_row_idx += 1;
251            continue;
252        }
253
254        // emit current batch of rows for current buffer
255        array_data.extend(cur_array, start_row_idx, end_row_idx);
256
257        // start new batch of rows
258        cur_array = array;
259        start_row_idx = row;
260        end_row_idx = start_row_idx + 1;
261    }
262
263    // emit final batch of rows
264    array_data.extend(cur_array, start_row_idx, end_row_idx);
265    Ok(make_array(array_data.freeze()))
266}
267
268/// Interleave rows by index from multiple [`RecordBatch`] instances and return a new [`RecordBatch`].
269///
270/// This function will call [`interleave`] on each array of the [`RecordBatch`] instances and assemble a new [`RecordBatch`].
271///
272/// # Example
273/// ```
274/// # use std::sync::Arc;
275/// # use arrow_array::{StringArray, Int32Array, RecordBatch, UInt32Array};
276/// # use arrow_schema::{DataType, Field, Schema};
277/// # use arrow_select::interleave::interleave_record_batch;
278///
279/// let schema = Arc::new(Schema::new(vec![
280///     Field::new("a", DataType::Int32, true),
281///     Field::new("b", DataType::Utf8, true),
282/// ]));
283///
284/// let batch1 = RecordBatch::try_new(
285///     schema.clone(),
286///     vec![
287///         Arc::new(Int32Array::from(vec![0, 1, 2])),
288///         Arc::new(StringArray::from(vec!["a", "b", "c"])),
289///     ],
290/// ).unwrap();
291///
292/// let batch2 = RecordBatch::try_new(
293///     schema.clone(),
294///     vec![
295///         Arc::new(Int32Array::from(vec![3, 4, 5])),
296///         Arc::new(StringArray::from(vec!["d", "e", "f"])),
297///     ],
298/// ).unwrap();
299///
300/// let indices = vec![(0, 1), (1, 2), (0, 0), (1, 1)];
301/// let interleaved = interleave_record_batch(&[&batch1, &batch2], &indices).unwrap();
302///
303/// let expected = RecordBatch::try_new(
304///     schema,
305///     vec![
306///         Arc::new(Int32Array::from(vec![1, 5, 0, 4])),
307///         Arc::new(StringArray::from(vec!["b", "f", "a", "e"])),
308///     ],
309/// ).unwrap();
310/// assert_eq!(interleaved, expected);
311/// ```
312pub fn interleave_record_batch(
313    record_batches: &[&RecordBatch],
314    indices: &[(usize, usize)],
315) -> Result<RecordBatch, ArrowError> {
316    let schema = record_batches[0].schema();
317    let columns = (0..schema.fields().len())
318        .map(|i| {
319            let column_values: Vec<&dyn Array> = record_batches
320                .iter()
321                .map(|batch| batch.column(i).as_ref())
322                .collect();
323            interleave(&column_values, indices)
324        })
325        .collect::<Result<Vec<_>, _>>()?;
326    RecordBatch::try_new(schema, columns)
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::{Int32Builder, ListBuilder};

    /// Builds a `ListArray` of `Int32` with a [`ListBuilder`]; a `None` row is a null list.
    fn int32_lists(rows: Vec<Option<Vec<Option<i32>>>>) -> ListArray {
        let mut builder = ListBuilder::new(Int32Builder::new());
        for row in rows {
            match row {
                Some(items) => {
                    for item in items {
                        builder.values().append_option(item);
                    }
                    builder.append(true);
                }
                None => builder.append(false),
            }
        }
        builder.finish()
    }

    #[test]
    fn test_primitive() {
        let a = Int32Array::from_iter_values([1, 2, 3, 4]);
        let b = Int32Array::from_iter_values([5, 6, 7]);
        let c = Int32Array::from_iter_values([8, 9, 10]);

        let picks = [(0, 3), (0, 3), (2, 2), (2, 0), (1, 1)];
        let out = interleave(&[&a, &b, &c], &picks).unwrap();

        assert_eq!(out.as_primitive::<Int32Type>().values(), &[4, 4, 10, 8, 6]);
    }

    #[test]
    fn test_primitive_nulls() {
        let a = Int32Array::from_iter_values([1, 2, 3, 4]);
        let b = Int32Array::from_iter([Some(1), Some(4), None]);

        let picks = [(0, 1), (1, 2), (1, 2), (0, 3), (0, 2)];
        let out = interleave(&[&a, &b], &picks).unwrap();

        let actual: Vec<_> = out.as_primitive::<Int32Type>().iter().collect();
        assert_eq!(&actual, &[Some(2), None, None, Some(4), Some(3)])
    }

    #[test]
    fn test_primitive_empty() {
        let a = Int32Array::from_iter_values([1, 2, 3, 4]);
        let out = interleave(&[&a], &[]).unwrap();

        assert!(out.is_empty());
        assert_eq!(out.data_type(), &DataType::Int32);
    }

    #[test]
    fn test_strings() {
        let a = StringArray::from_iter_values(["a", "b", "c"]);
        let b = StringArray::from_iter_values(["hello", "world", "foo"]);

        let picks = [(0, 2), (0, 2), (1, 0), (1, 1), (0, 1)];
        let out = interleave(&[&a, &b], &picks).unwrap();

        let actual: Vec<_> = out.as_string::<i32>().iter().collect();
        let expected = [
            Some("c"),
            Some("c"),
            Some("hello"),
            Some("world"),
            Some("b"),
        ];
        assert_eq!(&actual, &expected)
    }

    #[test]
    fn test_interleave_dictionary() {
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "a", "b"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["a", "c", "a", "c", "a"]);

        // Should not recompute dictionary
        let picks = [(0, 2), (0, 2), (0, 2), (1, 0), (1, 1), (0, 1)];
        let out = interleave(&[&a, &b], &picks).unwrap();
        let dict = out.as_dictionary::<Int32Type>();
        assert_eq!(dict.values().len(), 5);

        let typed = dict.downcast_dict::<StringArray>().unwrap();
        let strings: Vec<_> = typed.into_iter().map(Option::unwrap).collect();
        assert_eq!(&strings, &["c", "c", "c", "a", "c", "b"]);

        // Should recompute dictionary
        let picks = [(0, 2), (0, 2), (1, 1)];
        let out = interleave(&[&a, &b], &picks).unwrap();
        let dict = out.as_dictionary::<Int32Type>();
        assert_eq!(dict.values().len(), 1);

        let typed = dict.downcast_dict::<StringArray>().unwrap();
        let strings: Vec<_> = typed.into_iter().map(Option::unwrap).collect();
        assert_eq!(&strings, &["c", "c", "c"]);
    }

    #[test]
    fn test_lists() {
        // [[1, 2], null, [3]]
        let a = int32_lists(vec![
            Some(vec![Some(1), Some(2)]),
            None,
            Some(vec![Some(3)]),
        ]);

        // [[4], null, [5, 6, null]]
        let b = int32_lists(vec![
            Some(vec![Some(4)]),
            None,
            Some(vec![Some(5), Some(6), None]),
        ]);

        let picks = [(0, 2), (0, 1), (1, 0), (1, 2), (1, 1)];
        let out = interleave(&[&a, &b], &picks).unwrap();
        let actual = out.as_any().downcast_ref::<ListArray>().unwrap();

        // [[3], null, [4], [5, 6, null], null]
        let expected = int32_lists(vec![
            Some(vec![Some(3)]),
            None,
            Some(vec![Some(4)]),
            Some(vec![Some(5), Some(6), None]),
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn interleave_sparse_nulls() {
        // Dictionary with 100 values of which only the first 10 are referenced
        let dense_values = StringArray::from_iter_values((0..100).map(|x| x.to_string()));
        let dense_keys = Int32Array::from_iter_values(0..10);
        let dict_a = DictionaryArray::new(dense_keys, Arc::new(dense_values));

        // Dictionary with no values and only null keys
        let no_values = StringArray::new_null(0);
        let null_keys = Int32Array::new_null(10);
        let dict_b = DictionaryArray::new(null_keys, Arc::new(no_values));

        let picks = &[(0, 0), (0, 1), (0, 2), (1, 0)];
        let out = interleave(&[&dict_a, &dict_b], picks).unwrap();

        let expected =
            DictionaryArray::<Int32Type>::from_iter(vec![Some("0"), Some("1"), Some("2"), None]);
        assert_eq!(out.as_ref(), &expected)
    }
}