arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
28};
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
31
32use num::{One, Zero};
33
34/// Take elements by index from [Array], creating a new [Array] from those indexes.
35///
36/// ```text
37/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
38/// │        A        │      │    0    │                              │        A        │
39/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
40/// │        D        │      │    2    │                              │        B        │
41/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
42/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
43/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
44/// │        C        │      │    1    │                              │        D        │
45/// ├─────────────────┤      └─────────┘                              └─────────────────┘
46/// │        E        │
47/// └─────────────────┘
48///    values array          indices array                              result
49/// ```
50///
51/// For selecting values by index from multiple arrays see [`crate::interleave`]
52///
53/// Note that this kernel, similar to other kernels in this crate,
54/// will avoid allocating where not necessary. Consequently
55/// the returned array may share buffers with the inputs
56///
57/// # Errors
58/// This function errors whenever:
59/// * An index cannot be casted to `usize` (typically 32 bit architectures)
60/// * An index is out of bounds and `options` is set to check bounds.
61///
62/// # Safety
63///
64/// When `options` is not set to check bounds, taking indexes after `len` will panic.
65///
66/// # Examples
67/// ```
68/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
69/// # use arrow_select::take::take;
70/// let values = StringArray::from(vec!["zero", "one", "two"]);
71///
72/// // Take items at index 2, and 1:
73/// let indices = UInt32Array::from(vec![2, 1]);
74/// let taken = take(&values, &indices, None).unwrap();
75/// let taken = taken.as_string::<i32>();
76///
77/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
78/// ```
79pub fn take(
80    values: &dyn Array,
81    indices: &dyn Array,
82    options: Option<TakeOptions>,
83) -> Result<ArrayRef, ArrowError> {
84    let options = options.unwrap_or_default();
85    macro_rules! helper {
86        ($t:ty, $values:expr, $indices:expr, $options:expr) => {{
87            let indices = indices.as_primitive::<$t>();
88            if $options.check_bounds {
89                check_bounds($values.len(), indices)?;
90            }
91            let indices = indices.to_indices();
92            take_impl($values, &indices)
93        }};
94    }
95    downcast_integer! {
96        indices.data_type() => (helper, values, indices, options),
97        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
98    }
99}
100
101/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
102/// [`Vec<ArrayRef>`] from those indices.
103///
104/// ```text
105/// ┌────────┬────────┐
106/// │        │        │           ┌────────┐                                ┌────────┬────────┐
107/// │   A    │   1    │           │        │                                │        │        │
108/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
109/// │        │        │           ├────────┤                                ├────────┼────────┤
110/// │   D    │   4    │           │        │                                │        │        │
111/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
112/// │        │        │           ├────────┤                                ├────────┼────────┤
113/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
114/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
115/// │        │        │           ├────────┤                                ├────────┼────────┤
116/// │   C    │   3    │           │        │                                │        │        │
117/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
118/// │        │        │           └────────┘                                └────────┼────────┘
119/// │   E    │   5    │
120/// └────────┴────────┘
121///    values arrays             indices array                                      result
122/// ```
123///
124/// # Errors
125/// This function errors whenever:
126/// * An index cannot be casted to `usize` (typically 32 bit architectures)
127/// * An index is out of bounds and `options` is set to check bounds.
128///
129/// # Safety
130///
131/// When `options` is not set to check bounds, taking indexes after `len` will panic.
132///
133/// # Examples
134/// ```
135/// # use std::sync::Arc;
136/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
137/// # use arrow_select::take::{take, take_arrays};
138/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
139/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
140///
141/// // Take items at index 2, and 1:
142/// let indices = UInt32Array::from(vec![2, 1]);
143/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
144/// let taken_string = taken_arrays[0].as_string::<i32>();
145/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
146/// let taken_values = taken_arrays[1].as_primitive();
147/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
148/// ```
149pub fn take_arrays(
150    arrays: &[ArrayRef],
151    indices: &dyn Array,
152    options: Option<TakeOptions>,
153) -> Result<Vec<ArrayRef>, ArrowError> {
154    arrays
155        .iter()
156        .map(|array| take(array.as_ref(), indices, options.clone()))
157        .collect()
158}
159
160/// Verifies that the non-null values of `indices` are all `< len`
161fn check_bounds<T: ArrowPrimitiveType>(
162    len: usize,
163    indices: &PrimitiveArray<T>,
164) -> Result<(), ArrowError> {
165    if indices.null_count() > 0 {
166        indices.iter().flatten().try_for_each(|index| {
167            let ix = index
168                .to_usize()
169                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
170            if ix >= len {
171                return Err(ArrowError::ComputeError(format!(
172                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
173                )));
174            }
175            Ok(())
176        })
177    } else {
178        indices.values().iter().try_for_each(|index| {
179            let ix = index
180                .to_usize()
181                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
182            if ix >= len {
183                return Err(ArrowError::ComputeError(format!(
184                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
185                )));
186            }
187            Ok(())
188        })
189    }
190}
191
/// Dispatches `take` on the concrete type of `values`, using already-typed `indices`.
///
/// Primitive arrays are handled by the first `downcast_primitive_array!` arm; every
/// other supported [`DataType`] has an explicit arm below. Struct and union arms call
/// this function again for their children, and the list helpers call it for the list
/// values.
///
/// This function does not itself call `check_bounds`; callers that want bounds errors
/// instead of panics must validate `indices` first.
///
/// # Errors
/// Propagates any [`ArrowError`] returned by the per-type helpers.
///
/// # Panics
/// Panics with `unimplemented!` for data types that have no arm here.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            // NOTE(review): the `as u32` cast assumes `length` is non-negative — confirm.
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // Reuse the list kernel: convert the map to a `ListArray`, take from it, then
            // put the original map data type back on the resulting array data.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // NOTE(review): validation is skipped here; presumably sound because only the
            // data type is swapped back to the input's map type — confirm.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays  = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // Create the null bit buffer.
            // An output slot is valid only when its index is non-null and the struct
            // slot it refers to is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Take applied to a null array produces a null array.
            if values.len() >= indices.len() {
                // If the existing null array is as big as the indices, we can use a slice of it
                // to avoid allocating a new null array.
                Ok(values.slice(0, indices.len()))
            } else {
                // If the existing null array isn't big enough, create a new one.
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            // Sparse union: the type ids and every child are taken with the same
            // indices, and the result is built without an offsets buffer.
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and the child offset of every selected slot.
            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);

            // For each field, keep only the offsets of the selected slots carrying that
            // field's type id, and take exactly those entries from the child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // Running count of values emitted so far per type id.
            // NOTE(review): indexing with `i as usize` assumes type ids lie in 0..128
            // (non-negative `i8`) — confirm.
            let mut child_offsets = [0; 128];

            // Rebuild the dense offsets: each new child holds its values in output
            // order, so the n-th slot of a given type id points at offset n.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
346
/// Configuration for the `take` kernel.
///
/// The default value disables bounds checking.
#[derive(Debug, Default, Clone)]
pub struct TakeOptions {
    /// Whether indices are validated against the length of the values before taking.
    ///
    /// When `true`, out-of-bounds indices produce an `ArrowError`.
    /// When `false`, out-of-bounds indices make the kernel panic.
    pub check_bounds: bool,
}
355
356#[inline(always)]
357fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
358    index
359        .to_usize()
360        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
361}
362
363/// `take` implementation for all primitive arrays
364///
365/// This checks if an `indices` slot is populated, and gets the value from `values`
366///  as the populated index.
367/// If the `indices` slot is null, a null value is returned.
368/// For example, given:
369///     values:  [1, 2, 3, null, 5]
370///     indices: [0, null, 4, 3]
371/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
372fn take_primitive<T, I>(
373    values: &PrimitiveArray<T>,
374    indices: &PrimitiveArray<I>,
375) -> Result<PrimitiveArray<T>, ArrowError>
376where
377    T: ArrowPrimitiveType,
378    I: ArrowPrimitiveType,
379{
380    let values_buf = take_native(values.values(), indices);
381    let nulls = take_nulls(values.nulls(), indices);
382    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
383}
384
385#[inline(never)]
386fn take_nulls<I: ArrowPrimitiveType>(
387    values: Option<&NullBuffer>,
388    indices: &PrimitiveArray<I>,
389) -> Option<NullBuffer> {
390    match values.filter(|n| n.null_count() > 0) {
391        Some(n) => {
392            let buffer = take_bits(n.inner(), indices);
393            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
394        }
395        None => indices.nulls().cloned(),
396    }
397}
398
399#[inline(never)]
400fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
401    values: &[T],
402    indices: &PrimitiveArray<I>,
403) -> ScalarBuffer<T> {
404    match indices.nulls().filter(|n| n.null_count() > 0) {
405        Some(n) => indices
406            .values()
407            .iter()
408            .enumerate()
409            .map(|(idx, index)| match values.get(index.as_usize()) {
410                Some(v) => *v,
411                None => match n.is_null(idx) {
412                    true => T::default(),
413                    false => panic!("Out-of-bounds index {index:?}"),
414                },
415            })
416            .collect(),
417        None => indices
418            .values()
419            .iter()
420            .map(|index| values[index.as_usize()])
421            .collect(),
422    }
423}
424
425#[inline(never)]
426fn take_bits<I: ArrowPrimitiveType>(
427    values: &BooleanBuffer,
428    indices: &PrimitiveArray<I>,
429) -> BooleanBuffer {
430    let len = indices.len();
431
432    match indices.nulls().filter(|n| n.null_count() > 0) {
433        Some(nulls) => {
434            let mut output_buffer = MutableBuffer::new_null(len);
435            let output_slice = output_buffer.as_slice_mut();
436            nulls.valid_indices().for_each(|idx| {
437                if values.value(indices.value(idx).as_usize()) {
438                    bit_util::set_bit(output_slice, idx);
439                }
440            });
441            BooleanBuffer::new(output_buffer.into(), 0, len)
442        }
443        None => {
444            BooleanBuffer::collect_bool(len, |idx: usize| {
445                // SAFETY: idx<indices.len()
446                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
447            })
448        }
449    }
450}
451
452/// `take` implementation for boolean arrays
453fn take_boolean<IndexType: ArrowPrimitiveType>(
454    values: &BooleanArray,
455    indices: &PrimitiveArray<IndexType>,
456) -> BooleanArray {
457    let val_buf = take_bits(values.values(), indices);
458    let null_buf = take_nulls(values.nulls(), indices);
459    BooleanArray::new(val_buf, null_buf)
460}
461
462/// `take` implementation for string arrays
463fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
464    array: &GenericByteArray<T>,
465    indices: &PrimitiveArray<IndexType>,
466) -> Result<GenericByteArray<T>, ArrowError> {
467    let data_len = indices.len();
468
469    let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
470    let mut offsets = MutableBuffer::new(bytes_offset);
471    offsets.push(T::Offset::default());
472
473    let mut values = MutableBuffer::new(0);
474
475    let nulls;
476    if array.null_count() == 0 && indices.null_count() == 0 {
477        offsets.extend(indices.values().iter().map(|index| {
478            let s: &[u8] = array.value(index.as_usize()).as_ref();
479            values.extend_from_slice(s);
480            T::Offset::usize_as(values.len())
481        }));
482        nulls = None
483    } else if indices.null_count() == 0 {
484        let num_bytes = bit_util::ceil(data_len, 8);
485
486        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
487        let null_slice = null_buf.as_slice_mut();
488        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
489            let index = index.as_usize();
490            if array.is_valid(index) {
491                let s: &[u8] = array.value(index).as_ref();
492                values.extend_from_slice(s.as_ref());
493            } else {
494                bit_util::unset_bit(null_slice, i);
495            }
496            T::Offset::usize_as(values.len())
497        }));
498        nulls = Some(null_buf.into());
499    } else if array.null_count() == 0 {
500        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
501            if indices.is_valid(i) {
502                let s: &[u8] = array.value(index.as_usize()).as_ref();
503                values.extend_from_slice(s);
504            }
505            T::Offset::usize_as(values.len())
506        }));
507        nulls = indices.nulls().map(|b| b.inner().sliced());
508    } else {
509        let num_bytes = bit_util::ceil(data_len, 8);
510
511        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
512        let null_slice = null_buf.as_slice_mut();
513        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
514            // check index is valid before using index. The value in
515            // NULL index slots may not be within bounds of array
516            let index = index.as_usize();
517            if indices.is_valid(i) && array.is_valid(index) {
518                let s: &[u8] = array.value(index).as_ref();
519                values.extend_from_slice(s);
520            } else {
521                // set null bit
522                bit_util::unset_bit(null_slice, i);
523            }
524            T::Offset::usize_as(values.len())
525        }));
526        nulls = Some(null_buf.into())
527    }
528
529    T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
530        "Offset overflow for {}BinaryArray: {}",
531        T::Offset::PREFIX,
532        values.len()
533    )))?;
534
535    let array_data = ArrayData::builder(T::DATA_TYPE)
536        .len(data_len)
537        .add_buffer(offsets.into())
538        .add_buffer(values.into())
539        .null_bit_buffer(nulls);
540
541    let array_data = unsafe { array_data.build_unchecked() };
542
543    Ok(GenericByteArray::from(array_data))
544}
545
546/// `take` implementation for byte view arrays
547fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
548    array: &GenericByteViewArray<T>,
549    indices: &PrimitiveArray<IndexType>,
550) -> Result<GenericByteViewArray<T>, ArrowError> {
551    let new_views = take_native(array.views(), indices);
552    let new_nulls = take_nulls(array.nulls(), indices);
553    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
554    Ok(unsafe {
555        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
556    })
557}
558
559/// `take` implementation for list arrays
560///
561/// Calculates the index and indexed offset for the inner array,
562/// applying `take` on the inner array, then reconstructing a list array
563/// with the indexed offsets
564fn take_list<IndexType, OffsetType>(
565    values: &GenericListArray<OffsetType::Native>,
566    indices: &PrimitiveArray<IndexType>,
567) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
568where
569    IndexType: ArrowPrimitiveType,
570    OffsetType: ArrowPrimitiveType,
571    OffsetType::Native: OffsetSizeTrait,
572    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
573{
574    // TODO: Some optimizations can be done here such as if it is
575    // taking the whole list or a contiguous sublist
576    let (list_indices, offsets, null_buf) =
577        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
578
579    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
580    let value_offsets = Buffer::from_vec(offsets);
581    // create a new list with taken data and computed null information
582    let list_data = ArrayDataBuilder::new(values.data_type().clone())
583        .len(indices.len())
584        .null_bit_buffer(Some(null_buf.into()))
585        .offset(0)
586        .add_child_data(taken.into_data())
587        .add_buffer(value_offsets);
588
589    let list_data = unsafe { list_data.build_unchecked() };
590
591    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
592}
593
594/// `take` implementation for `FixedSizeListArray`
595///
596/// Calculates the index and indexed offset for the inner array,
597/// applying `take` on the inner array, then reconstructing a list array
598/// with the indexed offsets
599fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
600    values: &FixedSizeListArray,
601    indices: &PrimitiveArray<IndexType>,
602    length: <UInt32Type as ArrowPrimitiveType>::Native,
603) -> Result<FixedSizeListArray, ArrowError> {
604    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
605    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
606
607    // determine null count and null buffer, which are a function of `values` and `indices`
608    let num_bytes = bit_util::ceil(indices.len(), 8);
609    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
610    let null_slice = null_buf.as_slice_mut();
611
612    for i in 0..indices.len() {
613        let index = indices
614            .value(i)
615            .to_usize()
616            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
617        if !indices.is_valid(i) || values.is_null(index) {
618            bit_util::unset_bit(null_slice, i);
619        }
620    }
621
622    let list_data = ArrayDataBuilder::new(values.data_type().clone())
623        .len(indices.len())
624        .null_bit_buffer(Some(null_buf.into()))
625        .offset(0)
626        .add_child_data(taken.into_data());
627
628    let list_data = unsafe { list_data.build_unchecked() };
629
630    Ok(FixedSizeListArray::from(list_data))
631}
632
633fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
634    values: &FixedSizeBinaryArray,
635    indices: &PrimitiveArray<IndexType>,
636    size: i32,
637) -> Result<FixedSizeBinaryArray, ArrowError> {
638    let nulls = values.nulls();
639    let array_iter = indices
640        .values()
641        .iter()
642        .map(|idx| {
643            let idx = maybe_usize::<IndexType::Native>(*idx)?;
644            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
645                Ok(Some(values.value(idx)))
646            } else {
647                Ok(None)
648            }
649        })
650        .collect::<Result<Vec<_>, ArrowError>>()?
651        .into_iter();
652
653    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
654}
655
656/// `take` implementation for dictionary arrays
657///
658/// applies `take` to the keys of the dictionary array and returns a new dictionary array
659/// with the same dictionary values and reordered keys
660fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
661    values: &DictionaryArray<T>,
662    indices: &PrimitiveArray<I>,
663) -> Result<DictionaryArray<T>, ArrowError> {
664    let new_keys = take_primitive(values.keys(), indices)?;
665    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
666}
667
668/// `take` implementation for run arrays
669///
670/// Finds physical indices for the given logical indices and builds output run array
671/// by taking values in the input run_array.values at the physical indices.
672/// The output run array will be run encoded on the physical indices and not on output values.
673/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
674/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
675/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
676fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
677    run_array: &RunArray<T>,
678    logical_indices: &PrimitiveArray<I>,
679) -> Result<RunArray<T>, ArrowError> {
680    // get physical indices for the input logical indices
681    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
682
683    // Run encode the physical indices into new_run_ends_builder
684    // Keep track of the physical indices to take in take_value_indices
685    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
686    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
687    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
688    let mut new_physical_len = 1;
689    for ix in 1..physical_indices.len() {
690        if physical_indices[ix] != physical_indices[ix - 1] {
691            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
692            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
693            new_physical_len += 1;
694        }
695    }
696    take_value_indices
697        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
698    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
699    let new_run_ends = unsafe {
700        // Safety:
701        // The function builds a valid run_ends array and hence need not be validated.
702        ArrayDataBuilder::new(T::DATA_TYPE)
703            .len(new_physical_len)
704            .null_count(0)
705            .add_buffer(new_run_ends_builder.finish())
706            .build_unchecked()
707    };
708
709    let take_value_indices: PrimitiveArray<I> = unsafe {
710        // Safety:
711        // The function builds a valid take_value_indices array and hence need not be validated.
712        ArrayDataBuilder::new(I::DATA_TYPE)
713            .len(new_physical_len)
714            .null_count(0)
715            .add_buffer(take_value_indices.finish())
716            .build_unchecked()
717            .into()
718    };
719
720    let new_values = take(run_array.values(), &take_value_indices, None)?;
721
722    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
723        .len(physical_indices.len())
724        .add_child_data(new_run_ends)
725        .add_child_data(new_values.into_data());
726    let array_data = unsafe {
727        // Safety:
728        //  This function builds a valid run array and hence can skip validation.
729        builder.build_unchecked()
730    };
731    Ok(array_data.into())
732}
733
734/// Takes/filters a list array's inner data using the offsets of the list array.
735///
736/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
737/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
738/// elements)
739#[allow(clippy::type_complexity)]
740fn take_value_indices_from_list<IndexType, OffsetType>(
741    list: &GenericListArray<OffsetType::Native>,
742    indices: &PrimitiveArray<IndexType>,
743) -> Result<
744    (
745        PrimitiveArray<OffsetType>,
746        Vec<OffsetType::Native>,
747        MutableBuffer,
748    ),
749    ArrowError,
750>
751where
752    IndexType: ArrowPrimitiveType,
753    OffsetType: ArrowPrimitiveType,
754    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
755    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
756{
757    // TODO: benchmark this function, there might be a faster unsafe alternative
758    let offsets: &[OffsetType::Native] = list.value_offsets();
759
760    let mut new_offsets = Vec::with_capacity(indices.len());
761    let mut values = Vec::new();
762    let mut current_offset = OffsetType::Native::zero();
763    // add first offset
764    new_offsets.push(OffsetType::Native::zero());
765
766    // Initialize null buffer
767    let num_bytes = bit_util::ceil(indices.len(), 8);
768    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
769    let null_slice = null_buf.as_slice_mut();
770
771    // compute the value indices, and set offsets accordingly
772    for i in 0..indices.len() {
773        if indices.is_valid(i) {
774            let ix = indices
775                .value(i)
776                .to_usize()
777                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
778            let start = offsets[ix];
779            let end = offsets[ix + 1];
780            current_offset += end - start;
781            new_offsets.push(current_offset);
782
783            let mut curr = start;
784
785            // if start == end, this slot is empty
786            while curr < end {
787                values.push(curr);
788                curr += One::one();
789            }
790            if !list.is_valid(ix) {
791                bit_util::unset_bit(null_slice, i);
792            }
793        } else {
794            bit_util::unset_bit(null_slice, i);
795            new_offsets.push(current_offset);
796        }
797    }
798
799    Ok((
800        PrimitiveArray::<OffsetType>::from(values),
801        new_offsets,
802        null_buf,
803    ))
804}
805
806/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
807fn take_value_indices_from_fixed_size_list<IndexType>(
808    list: &FixedSizeListArray,
809    indices: &PrimitiveArray<IndexType>,
810    length: <UInt32Type as ArrowPrimitiveType>::Native,
811) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
812where
813    IndexType: ArrowPrimitiveType,
814{
815    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
816
817    for i in 0..indices.len() {
818        if indices.is_valid(i) {
819            let index = indices
820                .value(i)
821                .to_usize()
822                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
823            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
824
825            // Safety: Range always has known length.
826            unsafe {
827                values.append_trusted_len_iter(start..start + length);
828            }
829        } else {
830            values.append_nulls(length as usize);
831        }
832    }
833
834    Ok(values.finish())
835}
836
/// To avoid generating take implementations for every index type, instead we
/// only generate for UInt32 and UInt64 and coerce inputs to these types
trait ToIndices {
    /// The primitive type of the index array returned by [`Self::to_indices`]
    type T: ArrowPrimitiveType;

    /// Returns the contents of `self` as a [`PrimitiveArray`] of type [`Self::T`]
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
844
// Implements `ToIndices` for `PrimitiveArray<$t>` by wrapping a clone of the
// array's values buffer in a `ScalarBuffer` of the native type of `$o`; the null
// mask is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let validity = self.nulls().cloned();
                let raw = self.values().inner().clone();
                PrimitiveArray::new(ScalarBuffer::new(raw, 0, self.len()), validity)
            }
        }
    };
}
857
// Implements `ToIndices` for `PrimitiveArray<$t>` where the array already has the
// target index type: `to_indices` returns a clone of the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                self.clone()
            }
        }
    };
}
869
// Implements `ToIndices` for `PrimitiveArray<$t>` by converting every value to the
// native type of `$o` with an `as` cast; the null mask is carried over unchanged.
//
// Note that `as` does not reject negative values of a signed `$t`: when `$o` is
// unsigned they wrap around to large values.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Must name the same type as the return type of `to_indices`, otherwise
            // the impl only compiles for a single choice of `$o`
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
882
// 8-bit index types are converted value by value to `UInt32Type`
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

// 16-bit index types are converted value by value to `UInt32Type`
to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit index types: `UInt32Type` is used as is, `Int32Type` reuses its buffer
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit index types: `UInt64Type` is used as is, `Int64Type` reuses its buffer
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
894
895/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
896///
897/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
898///
899/// # Example
900/// ```
901/// # use std::sync::Arc;
902/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
903/// # use arrow_schema::{DataType, Field, Schema};
904/// # use arrow_select::take::take_record_batch;
905///
906/// let schema = Arc::new(Schema::new(vec![
907///     Field::new("a", DataType::Int32, true),
908///     Field::new("b", DataType::Utf8, true),
909/// ]));
910/// let batch = RecordBatch::try_new(
911///     schema.clone(),
912///     vec![
913///         Arc::new(Int32Array::from_iter_values(0..20)),
914///         Arc::new(StringArray::from_iter_values(
915///             (0..20).map(|i| format!("str-{}", i)),
916///         )),
917///     ],
918/// )
919/// .unwrap();
920///
921/// let indices = UInt32Array::from(vec![1, 5, 10]);
922/// let taken = take_record_batch(&batch, &indices).unwrap();
923///
924/// let expected = RecordBatch::try_new(
925///     schema,
926///     vec![
927///         Arc::new(Int32Array::from(vec![1, 5, 10])),
928///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
929///     ],
930/// )
931/// .unwrap();
932/// assert_eq!(taken, expected);
933/// ```
934pub fn take_record_batch(
935    record_batch: &RecordBatch,
936    indices: &dyn Array,
937) -> Result<RecordBatch, ArrowError> {
938    let columns = record_batch
939        .columns()
940        .iter()
941        .map(|c| take(c, indices, None))
942        .collect::<Result<Vec<_>, _>>()?;
943    RecordBatch::try_new(record_batch.schema(), columns)
944}
945
946#[cfg(test)]
947mod tests {
948    use super::*;
949    use arrow_array::builder::*;
950    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
951    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
952
953    fn test_take_decimal_arrays(
954        data: Vec<Option<i128>>,
955        index: &UInt32Array,
956        options: Option<TakeOptions>,
957        expected_data: Vec<Option<i128>>,
958        precision: &u8,
959        scale: &i8,
960    ) -> Result<(), ArrowError> {
961        let output = data
962            .into_iter()
963            .collect::<Decimal128Array>()
964            .with_precision_and_scale(*precision, *scale)
965            .unwrap();
966
967        let expected = expected_data
968            .into_iter()
969            .collect::<Decimal128Array>()
970            .with_precision_and_scale(*precision, *scale)
971            .unwrap();
972
973        let expected = Arc::new(expected) as ArrayRef;
974        let output = take(&output, index, options).unwrap();
975        assert_eq!(&output, &expected);
976        Ok(())
977    }
978
979    fn test_take_boolean_arrays(
980        data: Vec<Option<bool>>,
981        index: &UInt32Array,
982        options: Option<TakeOptions>,
983        expected_data: Vec<Option<bool>>,
984    ) {
985        let output = BooleanArray::from(data);
986        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
987        let output = take(&output, index, options).unwrap();
988        assert_eq!(&output, &expected)
989    }
990
991    fn test_take_primitive_arrays<T>(
992        data: Vec<Option<T::Native>>,
993        index: &UInt32Array,
994        options: Option<TakeOptions>,
995        expected_data: Vec<Option<T::Native>>,
996    ) -> Result<(), ArrowError>
997    where
998        T: ArrowPrimitiveType,
999        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1000    {
1001        let output = PrimitiveArray::<T>::from(data);
1002        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1003        let output = take(&output, index, options)?;
1004        assert_eq!(&output, &expected);
1005        Ok(())
1006    }
1007
1008    fn test_take_primitive_arrays_non_null<T>(
1009        data: Vec<T::Native>,
1010        index: &UInt32Array,
1011        options: Option<TakeOptions>,
1012        expected_data: Vec<Option<T::Native>>,
1013    ) -> Result<(), ArrowError>
1014    where
1015        T: ArrowPrimitiveType,
1016        PrimitiveArray<T>: From<Vec<T::Native>>,
1017        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1018    {
1019        let output = PrimitiveArray::<T>::from(data);
1020        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1021        let output = take(&output, index, options)?;
1022        assert_eq!(&output, &expected);
1023        Ok(())
1024    }
1025
1026    fn test_take_impl_primitive_arrays<T, I>(
1027        data: Vec<Option<T::Native>>,
1028        index: &PrimitiveArray<I>,
1029        options: Option<TakeOptions>,
1030        expected_data: Vec<Option<T::Native>>,
1031    ) where
1032        T: ArrowPrimitiveType,
1033        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1034        I: ArrowPrimitiveType,
1035    {
1036        let output = PrimitiveArray::<T>::from(data);
1037        let expected = PrimitiveArray::<T>::from(expected_data);
1038        let output = take(&output, index, options).unwrap();
1039        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1040        assert_eq!(output, &expected)
1041    }
1042
1043    // create a simple struct for testing purposes
1044    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1045        let mut struct_builder = StructBuilder::new(
1046            Fields::from(vec![
1047                Field::new("a", DataType::Boolean, true),
1048                Field::new("b", DataType::Int32, true),
1049            ]),
1050            vec![
1051                Box::new(BooleanBuilder::with_capacity(values.len())),
1052                Box::new(Int32Builder::with_capacity(values.len())),
1053            ],
1054        );
1055
1056        for value in values {
1057            struct_builder
1058                .field_builder::<BooleanBuilder>(0)
1059                .unwrap()
1060                .append_option(value.and_then(|v| v.0));
1061            struct_builder
1062                .field_builder::<Int32Builder>(1)
1063                .unwrap()
1064                .append_option(value.and_then(|v| v.1));
1065            struct_builder.append(value.is_some());
1066        }
1067        struct_builder.finish()
1068    }
1069
1070    #[test]
1071    fn test_take_decimal128_non_null_indices() {
1072        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1073        let precision: u8 = 10;
1074        let scale: i8 = 5;
1075        test_take_decimal_arrays(
1076            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1077            &index,
1078            None,
1079            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1080            &precision,
1081            &scale,
1082        )
1083        .unwrap();
1084    }
1085
1086    #[test]
1087    fn test_take_decimal128() {
1088        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1089        let precision: u8 = 10;
1090        let scale: i8 = 5;
1091        test_take_decimal_arrays(
1092            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1093            &index,
1094            None,
1095            vec![Some(3), None, Some(1), Some(3), Some(2)],
1096            &precision,
1097            &scale,
1098        )
1099        .unwrap();
1100    }
1101
1102    #[test]
1103    fn test_take_primitive_non_null_indices() {
1104        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1105        test_take_primitive_arrays::<Int8Type>(
1106            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1107            &index,
1108            None,
1109            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1110        )
1111        .unwrap();
1112    }
1113
1114    #[test]
1115    fn test_take_primitive_non_null_values() {
1116        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1117        test_take_primitive_arrays::<Int8Type>(
1118            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1119            &index,
1120            None,
1121            vec![Some(3), None, Some(1), Some(3), Some(2)],
1122        )
1123        .unwrap();
1124    }
1125
1126    #[test]
1127    fn test_take_primitive_non_null() {
1128        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1129        test_take_primitive_arrays::<Int8Type>(
1130            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1131            &index,
1132            None,
1133            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1134        )
1135        .unwrap();
1136    }
1137
1138    #[test]
1139    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1140        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1141        let index = index.slice(2, 4);
1142        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1143
1144        assert_eq!(
1145            index,
1146            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1147        );
1148
1149        test_take_primitive_arrays_non_null::<Int64Type>(
1150            vec![0, 10, 20, 30, 40, 50],
1151            index,
1152            None,
1153            vec![Some(20), Some(30), None, None],
1154        )
1155        .unwrap();
1156    }
1157
1158    #[test]
1159    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1160        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1161        let index = index.slice(2, 4);
1162        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1163
1164        assert_eq!(
1165            index,
1166            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1167        );
1168
1169        test_take_primitive_arrays::<Int64Type>(
1170            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1171            index,
1172            None,
1173            vec![Some(20), Some(30), None, None],
1174        )
1175        .unwrap();
1176    }
1177
1178    #[test]
1179    fn test_take_primitive() {
1180        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1181
1182        // int8
1183        test_take_primitive_arrays::<Int8Type>(
1184            vec![Some(0), None, Some(2), Some(3), None],
1185            &index,
1186            None,
1187            vec![Some(3), None, None, Some(3), Some(2)],
1188        )
1189        .unwrap();
1190
1191        // int16
1192        test_take_primitive_arrays::<Int16Type>(
1193            vec![Some(0), None, Some(2), Some(3), None],
1194            &index,
1195            None,
1196            vec![Some(3), None, None, Some(3), Some(2)],
1197        )
1198        .unwrap();
1199
1200        // int32
1201        test_take_primitive_arrays::<Int32Type>(
1202            vec![Some(0), None, Some(2), Some(3), None],
1203            &index,
1204            None,
1205            vec![Some(3), None, None, Some(3), Some(2)],
1206        )
1207        .unwrap();
1208
1209        // int64
1210        test_take_primitive_arrays::<Int64Type>(
1211            vec![Some(0), None, Some(2), Some(3), None],
1212            &index,
1213            None,
1214            vec![Some(3), None, None, Some(3), Some(2)],
1215        )
1216        .unwrap();
1217
1218        // uint8
1219        test_take_primitive_arrays::<UInt8Type>(
1220            vec![Some(0), None, Some(2), Some(3), None],
1221            &index,
1222            None,
1223            vec![Some(3), None, None, Some(3), Some(2)],
1224        )
1225        .unwrap();
1226
1227        // uint16
1228        test_take_primitive_arrays::<UInt16Type>(
1229            vec![Some(0), None, Some(2), Some(3), None],
1230            &index,
1231            None,
1232            vec![Some(3), None, None, Some(3), Some(2)],
1233        )
1234        .unwrap();
1235
1236        // uint32
1237        test_take_primitive_arrays::<UInt32Type>(
1238            vec![Some(0), None, Some(2), Some(3), None],
1239            &index,
1240            None,
1241            vec![Some(3), None, None, Some(3), Some(2)],
1242        )
1243        .unwrap();
1244
1245        // int64
1246        test_take_primitive_arrays::<Int64Type>(
1247            vec![Some(0), None, Some(2), Some(-15), None],
1248            &index,
1249            None,
1250            vec![Some(-15), None, None, Some(-15), Some(2)],
1251        )
1252        .unwrap();
1253
1254        // interval_year_month
1255        test_take_primitive_arrays::<IntervalYearMonthType>(
1256            vec![Some(0), None, Some(2), Some(-15), None],
1257            &index,
1258            None,
1259            vec![Some(-15), None, None, Some(-15), Some(2)],
1260        )
1261        .unwrap();
1262
1263        // interval_day_time
1264        let v1 = IntervalDayTime::new(0, 0);
1265        let v2 = IntervalDayTime::new(2, 0);
1266        let v3 = IntervalDayTime::new(-15, 0);
1267        test_take_primitive_arrays::<IntervalDayTimeType>(
1268            vec![Some(v1), None, Some(v2), Some(v3), None],
1269            &index,
1270            None,
1271            vec![Some(v3), None, None, Some(v3), Some(v2)],
1272        )
1273        .unwrap();
1274
1275        // interval_month_day_nano
1276        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1277        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1278        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1279        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1280            vec![Some(v1), None, Some(v2), Some(v3), None],
1281            &index,
1282            None,
1283            vec![Some(v3), None, None, Some(v3), Some(v2)],
1284        )
1285        .unwrap();
1286
1287        // duration_second
1288        test_take_primitive_arrays::<DurationSecondType>(
1289            vec![Some(0), None, Some(2), Some(-15), None],
1290            &index,
1291            None,
1292            vec![Some(-15), None, None, Some(-15), Some(2)],
1293        )
1294        .unwrap();
1295
1296        // duration_millisecond
1297        test_take_primitive_arrays::<DurationMillisecondType>(
1298            vec![Some(0), None, Some(2), Some(-15), None],
1299            &index,
1300            None,
1301            vec![Some(-15), None, None, Some(-15), Some(2)],
1302        )
1303        .unwrap();
1304
1305        // duration_microsecond
1306        test_take_primitive_arrays::<DurationMicrosecondType>(
1307            vec![Some(0), None, Some(2), Some(-15), None],
1308            &index,
1309            None,
1310            vec![Some(-15), None, None, Some(-15), Some(2)],
1311        )
1312        .unwrap();
1313
1314        // duration_nanosecond
1315        test_take_primitive_arrays::<DurationNanosecondType>(
1316            vec![Some(0), None, Some(2), Some(-15), None],
1317            &index,
1318            None,
1319            vec![Some(-15), None, None, Some(-15), Some(2)],
1320        )
1321        .unwrap();
1322
1323        // float32
1324        test_take_primitive_arrays::<Float32Type>(
1325            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1326            &index,
1327            None,
1328            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1329        )
1330        .unwrap();
1331
1332        // float64
1333        test_take_primitive_arrays::<Float64Type>(
1334            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1335            &index,
1336            None,
1337            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1338        )
1339        .unwrap();
1340    }
1341
1342    #[test]
1343    fn test_take_preserve_timezone() {
1344        let index = Int64Array::from(vec![Some(0), None]);
1345
1346        let input = TimestampNanosecondArray::from(vec![
1347            1_639_715_368_000_000_000,
1348            1_639_715_368_000_000_000,
1349        ])
1350        .with_timezone("UTC".to_string());
1351        let result = take(&input, &index, None).unwrap();
1352        match result.data_type() {
1353            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1354                assert_eq!(tz.clone(), Some("UTC".into()))
1355            }
1356            _ => panic!(),
1357        }
1358    }
1359
1360    #[test]
1361    fn test_take_impl_primitive_with_int64_indices() {
1362        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1363
1364        // int16
1365        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1366            vec![Some(0), None, Some(2), Some(3), None],
1367            &index,
1368            None,
1369            vec![Some(3), None, None, Some(3), Some(2)],
1370        );
1371
1372        // int64
1373        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1374            vec![Some(0), None, Some(2), Some(-15), None],
1375            &index,
1376            None,
1377            vec![Some(-15), None, None, Some(-15), Some(2)],
1378        );
1379
1380        // uint64
1381        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1382            vec![Some(0), None, Some(2), Some(3), None],
1383            &index,
1384            None,
1385            vec![Some(3), None, None, Some(3), Some(2)],
1386        );
1387
1388        // duration_millisecond
1389        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1390            vec![Some(0), None, Some(2), Some(-15), None],
1391            &index,
1392            None,
1393            vec![Some(-15), None, None, Some(-15), Some(2)],
1394        );
1395
1396        // float32
1397        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1398            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1399            &index,
1400            None,
1401            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1402        );
1403    }
1404
1405    #[test]
1406    fn test_take_impl_primitive_with_uint8_indices() {
1407        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1408
1409        // int16
1410        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1411            vec![Some(0), None, Some(2), Some(3), None],
1412            &index,
1413            None,
1414            vec![Some(3), None, None, Some(3), Some(2)],
1415        );
1416
1417        // duration_millisecond
1418        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1419            vec![Some(0), None, Some(2), Some(-15), None],
1420            &index,
1421            None,
1422            vec![Some(-15), None, None, Some(-15), Some(2)],
1423        );
1424
1425        // float32
1426        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1427            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1428            &index,
1429            None,
1430            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1431        );
1432    }
1433
1434    #[test]
1435    fn test_take_bool() {
1436        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1437        // boolean
1438        test_take_boolean_arrays(
1439            vec![Some(false), None, Some(true), Some(false), None],
1440            &index,
1441            None,
1442            vec![Some(false), None, None, Some(false), Some(true)],
1443        );
1444    }
1445
1446    #[test]
1447    fn test_take_bool_nullable_index() {
1448        // indices where the masked invalid elements would be out of bounds
1449        let index_data = ArrayData::try_new(
1450            DataType::UInt32,
1451            6,
1452            Some(Buffer::from_iter(vec![
1453                false, true, false, true, false, true,
1454            ])),
1455            0,
1456            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1457            vec![],
1458        )
1459        .unwrap();
1460        let index = UInt32Array::from(index_data);
1461        test_take_boolean_arrays(
1462            vec![Some(true), None, Some(false)],
1463            &index,
1464            None,
1465            vec![None, Some(true), None, None, None, Some(false)],
1466        );
1467    }
1468
1469    #[test]
1470    fn test_take_bool_nullable_index_nonnull_values() {
1471        // indices where the masked invalid elements would be out of bounds
1472        let index_data = ArrayData::try_new(
1473            DataType::UInt32,
1474            6,
1475            Some(Buffer::from_iter(vec![
1476                false, true, false, true, false, true,
1477            ])),
1478            0,
1479            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1480            vec![],
1481        )
1482        .unwrap();
1483        let index = UInt32Array::from(index_data);
1484        test_take_boolean_arrays(
1485            vec![Some(true), Some(true), Some(false)],
1486            &index,
1487            None,
1488            vec![None, Some(true), None, Some(true), None, Some(false)],
1489        );
1490    }
1491
1492    #[test]
1493    fn test_take_bool_with_offset() {
1494        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1495        let index = index.slice(2, 4);
1496        let index = index
1497            .as_any()
1498            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1499            .unwrap();
1500
1501        // boolean
1502        test_take_boolean_arrays(
1503            vec![Some(false), None, Some(true), Some(false), None],
1504            index,
1505            None,
1506            vec![None, Some(false), Some(true), None],
1507        );
1508    }
1509
1510    fn _test_take_string<'a, K>()
1511    where
1512        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1513    {
1514        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1515
1516        let array = K::from(vec![
1517            Some("one"),
1518            None,
1519            Some("three"),
1520            Some("four"),
1521            Some("five"),
1522        ]);
1523        let actual = take(&array, &index, None).unwrap();
1524        assert_eq!(actual.len(), index.len());
1525
1526        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1527
1528        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1529
1530        assert_eq!(actual, &expected);
1531    }
1532
    // Runs the shared string check with `StringArray`
    #[test]
    fn test_take_string() {
        _test_take_string::<StringArray>()
    }
1537
    // Runs the shared string check with `LargeStringArray`
    #[test]
    fn test_take_large_string() {
        _test_take_string::<LargeStringArray>()
    }
1542
1543    #[test]
1544    fn test_take_slice_string() {
1545        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1546        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1547        let indices_slice = indices.slice(1, 4);
1548        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1549        let result = take(&strings, &indices_slice, None).unwrap();
1550        assert_eq!(result.as_ref(), &expected);
1551    }
1552
1553    fn _test_byte_view<T>()
1554    where
1555        T: ByteViewType,
1556        str: AsRef<T::Native>,
1557        T::Native: PartialEq,
1558    {
1559        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1560        let array = {
1561            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1562            let mut builder = GenericByteViewBuilder::<T>::new();
1563            builder.append_value("hello");
1564            builder.append_value("world");
1565            builder.append_null();
1566            builder.append_value("large payload over 12 bytes");
1567            builder.append_value("lulu");
1568            builder.finish()
1569        };
1570
1571        let actual = take(&array, &index, None).unwrap();
1572
1573        assert_eq!(actual.len(), index.len());
1574
1575        let expected = {
1576            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1577            let mut builder = GenericByteViewBuilder::<T>::new();
1578            builder.append_value("large payload over 12 bytes");
1579            builder.append_null();
1580            builder.append_value("world");
1581            builder.append_value("large payload over 12 bytes");
1582            builder.append_value("lulu");
1583            builder.append_null();
1584            builder.finish()
1585        };
1586
1587        assert_eq!(actual.as_ref(), &expected);
1588    }
1589
    // Runs the shared byte view check with `StringViewType`
    #[test]
    fn test_take_string_view() {
        _test_byte_view::<StringViewType>()
    }
1594
    // Runs the shared byte view check with `BinaryViewType`
    #[test]
    fn test_take_binary_view() {
        _test_byte_view::<BinaryViewType>()
    }
1599
1600    macro_rules! test_take_list {
1601        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1602            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1603            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1604            // Construct offsets
1605            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1606            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1607            // Construct a list array from the above two
1608            let list_data_type =
1609                DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, false)));
1610            let list_data = ArrayData::builder(list_data_type.clone())
1611                .len(4)
1612                .add_buffer(value_offsets)
1613                .add_child_data(value_data)
1614                .build()
1615                .unwrap();
1616            let list_array = $list_array_type::from(list_data);
1617
1618            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1619            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1620
1621            let a = take(&list_array, &index, None).unwrap();
1622            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1623
1624            // construct a value array with expected results:
1625            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1626            let expected_data = Int32Array::from(vec![
1627                Some(2),
1628                Some(3),
1629                Some(-1),
1630                Some(-2),
1631                Some(-1),
1632                Some(0),
1633                Some(0),
1634                Some(0),
1635            ])
1636            .into_data();
1637            // construct offsets
1638            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1639            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1640            // construct list array from the two
1641            let expected_list_data = ArrayData::builder(list_data_type)
1642                .len(5)
1643                // null buffer remains the same as only the indices have nulls
1644                .nulls(index.nulls().cloned())
1645                .add_buffer(expected_offsets)
1646                .add_child_data(expected_data)
1647                .build()
1648                .unwrap();
1649            let expected_list_array = $list_array_type::from(expected_list_data);
1650
1651            assert_eq!(a, &expected_list_array);
1652        }};
1653    }
1654
1655    macro_rules! test_take_list_with_value_nulls {
1656        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1657            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1658            let value_data = Int32Array::from(vec![
1659                Some(0),
1660                None,
1661                Some(0),
1662                Some(-1),
1663                Some(-2),
1664                Some(3),
1665                None,
1666                Some(5),
1667                None,
1668            ])
1669            .into_data();
1670            // Construct offsets
1671            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1672            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1673            // Construct a list array from the above two
1674            let list_data_type =
1675                DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, true)));
1676            let list_data = ArrayData::builder(list_data_type.clone())
1677                .len(4)
1678                .add_buffer(value_offsets)
1679                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1680                .add_child_data(value_data)
1681                .build()
1682                .unwrap();
1683            let list_array = $list_array_type::from(list_data);
1684
1685            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1686            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1687
1688            let a = take(&list_array, &index, None).unwrap();
1689            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1690
1691            // construct a value array with expected results:
1692            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1693            let expected_data = Int32Array::from(vec![
1694                None,
1695                Some(-1),
1696                Some(-2),
1697                Some(3),
1698                Some(5),
1699                None,
1700                Some(0),
1701                None,
1702                Some(0),
1703            ])
1704            .into_data();
1705            // construct offsets
1706            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1707            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1708            // construct list array from the two
1709            let expected_list_data = ArrayData::builder(list_data_type)
1710                .len(5)
1711                // null buffer remains the same as only the indices have nulls
1712                .nulls(index.nulls().cloned())
1713                .add_buffer(expected_offsets)
1714                .add_child_data(expected_data)
1715                .build()
1716                .unwrap();
1717            let expected_list_array = $list_array_type::from(expected_list_data);
1718
1719            assert_eq!(a, &expected_list_array);
1720        }};
1721    }
1722
1723    macro_rules! test_take_list_with_nulls {
1724        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1725            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1726            let value_data = Int32Array::from(vec![
1727                Some(0),
1728                None,
1729                Some(0),
1730                Some(-1),
1731                Some(-2),
1732                Some(3),
1733                Some(5),
1734                None,
1735            ])
1736            .into_data();
1737            // Construct offsets
1738            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1739            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1740            // Construct a list array from the above two
1741            let list_data_type =
1742                DataType::$list_data_type(Arc::new(Field::new("item", DataType::Int32, true)));
1743            let list_data = ArrayData::builder(list_data_type.clone())
1744                .len(4)
1745                .add_buffer(value_offsets)
1746                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1747                .add_child_data(value_data)
1748                .build()
1749                .unwrap();
1750            let list_array = $list_array_type::from(list_data);
1751
1752            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1753            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1754
1755            let a = take(&list_array, &index, None).unwrap();
1756            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1757
1758            // construct a value array with expected results:
1759            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1760            let expected_data = Int32Array::from(vec![
1761                Some(-1),
1762                Some(-2),
1763                Some(3),
1764                Some(5),
1765                None,
1766                Some(0),
1767                None,
1768                Some(0),
1769            ])
1770            .into_data();
1771            // construct offsets
1772            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1773            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1774            // construct list array from the two
1775            let mut null_bits: [u8; 1] = [0; 1];
1776            bit_util::set_bit(&mut null_bits, 2);
1777            bit_util::set_bit(&mut null_bits, 3);
1778            bit_util::set_bit(&mut null_bits, 4);
1779            let expected_list_data = ArrayData::builder(list_data_type)
1780                .len(5)
1781                // null buffer must be recalculated as both values and indices have nulls
1782                .null_bit_buffer(Some(Buffer::from(null_bits)))
1783                .add_buffer(expected_offsets)
1784                .add_child_data(expected_data)
1785                .build()
1786                .unwrap();
1787            let expected_list_array = $list_array_type::from(expected_list_data);
1788
1789            assert_eq!(a, &expected_list_array);
1790        }};
1791    }
1792
1793    fn do_take_fixed_size_list_test<T>(
1794        length: <Int32Type as ArrowPrimitiveType>::Native,
1795        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1796        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1797        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1798    ) where
1799        T: ArrowPrimitiveType,
1800        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1801    {
1802        let indices = UInt32Array::from(indices);
1803
1804        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1805
1806        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1807
1808        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1809
1810        assert_eq!(&output, &expected)
1811    }
1812
    /// Runs the `test_take_list!` scenario with `i32` offsets, i.e. on a
    /// `ListArray` typed as `DataType::List`.
    #[test]
    fn test_take_list() {
        test_take_list!(i32, List, ListArray);
    }
1817
    /// Runs the `test_take_list!` scenario with `i64` offsets, i.e. on a
    /// `LargeListArray` typed as `DataType::LargeList`.
    #[test]
    fn test_take_large_list() {
        test_take_list!(i64, LargeList, LargeListArray);
    }
1822
    /// Runs the `test_take_list_with_value_nulls!` scenario with `i32` offsets,
    /// i.e. on a `ListArray` typed as `DataType::List`.
    #[test]
    fn test_take_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i32, List, ListArray);
    }
1827
    /// Runs the `test_take_list_with_value_nulls!` scenario with `i64` offsets,
    /// i.e. on a `LargeListArray` typed as `DataType::LargeList`.
    #[test]
    fn test_take_large_list_with_value_nulls() {
        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
    }
1832
    /// Runs the `test_take_list_with_nulls!` scenario with `i32` offsets,
    /// i.e. on a `ListArray` typed as `DataType::List`.
    #[test]
    fn test_test_take_list_with_nulls() {
        test_take_list_with_nulls!(i32, List, ListArray);
    }
1837
    /// Runs the `test_take_list_with_nulls!` scenario with `i64` offsets,
    /// i.e. on a `LargeListArray` typed as `DataType::LargeList`.
    #[test]
    fn test_test_take_large_list_with_nulls() {
        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
    }
1842
1843    #[test]
1844    fn test_take_fixed_size_list() {
1845        do_take_fixed_size_list_test::<Int32Type>(
1846            3,
1847            vec![
1848                Some(vec![None, Some(1), Some(2)]),
1849                Some(vec![Some(3), Some(4), None]),
1850                Some(vec![Some(6), Some(7), Some(8)]),
1851            ],
1852            vec![2, 1, 0],
1853            vec![
1854                Some(vec![Some(6), Some(7), Some(8)]),
1855                Some(vec![Some(3), Some(4), None]),
1856                Some(vec![None, Some(1), Some(2)]),
1857            ],
1858        );
1859
1860        do_take_fixed_size_list_test::<UInt8Type>(
1861            1,
1862            vec![
1863                Some(vec![Some(1)]),
1864                Some(vec![Some(2)]),
1865                Some(vec![Some(3)]),
1866                Some(vec![Some(4)]),
1867                Some(vec![Some(5)]),
1868                Some(vec![Some(6)]),
1869                Some(vec![Some(7)]),
1870                Some(vec![Some(8)]),
1871            ],
1872            vec![2, 7, 0],
1873            vec![
1874                Some(vec![Some(3)]),
1875                Some(vec![Some(8)]),
1876                Some(vec![Some(1)]),
1877            ],
1878        );
1879
1880        do_take_fixed_size_list_test::<UInt64Type>(
1881            3,
1882            vec![
1883                Some(vec![Some(10), Some(11), Some(12)]),
1884                Some(vec![Some(13), Some(14), Some(15)]),
1885                None,
1886                Some(vec![Some(16), Some(17), Some(18)]),
1887            ],
1888            vec![3, 2, 1, 2, 0],
1889            vec![
1890                Some(vec![Some(16), Some(17), Some(18)]),
1891                None,
1892                Some(vec![Some(13), Some(14), Some(15)]),
1893                None,
1894                Some(vec![Some(10), Some(11), Some(12)]),
1895            ],
1896        );
1897    }
1898
1899    #[test]
1900    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1901    fn test_take_list_out_of_bounds() {
1902        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1903        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1904        // Construct offsets
1905        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1906        // Construct a list array from the above two
1907        let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
1908        let list_data = ArrayData::builder(list_data_type)
1909            .len(3)
1910            .add_buffer(value_offsets)
1911            .add_child_data(value_data)
1912            .build()
1913            .unwrap();
1914        let list_array = ListArray::from(list_data);
1915
1916        let index = UInt32Array::from(vec![1000]);
1917
1918        // A panic is expected here since we have not supplied the check_bounds
1919        // option.
1920        take(&list_array, &index, None).unwrap();
1921    }
1922
1923    #[test]
1924    fn test_take_map() {
1925        let values = Int32Array::from(vec![1, 2, 3, 4]);
1926        let array =
1927            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1928                .unwrap();
1929
1930        let index = UInt32Array::from(vec![0]);
1931
1932        let result = take(&array, &index, None).unwrap();
1933        let expected: ArrayRef = Arc::new(
1934            MapArray::new_from_strings(
1935                vec!["a", "b", "c"].into_iter(),
1936                &values.slice(0, 3),
1937                &[0, 3],
1938            )
1939            .unwrap(),
1940        );
1941        assert_eq!(&expected, &result);
1942    }
1943
1944    #[test]
1945    fn test_take_struct() {
1946        let array = create_test_struct(vec![
1947            Some((Some(true), Some(42))),
1948            Some((Some(false), Some(28))),
1949            Some((Some(false), Some(19))),
1950            Some((Some(true), Some(31))),
1951            None,
1952        ]);
1953
1954        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1955        let actual = take(&array, &index, None).unwrap();
1956        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1957        assert_eq!(index.len(), actual.len());
1958        assert_eq!(1, actual.null_count());
1959
1960        let expected = create_test_struct(vec![
1961            Some((Some(true), Some(42))),
1962            Some((Some(true), Some(31))),
1963            Some((Some(false), Some(28))),
1964            Some((Some(true), Some(42))),
1965            Some((Some(false), Some(19))),
1966            None,
1967        ]);
1968
1969        assert_eq!(&expected, actual);
1970    }
1971
1972    #[test]
1973    fn test_take_struct_with_null_indices() {
1974        let array = create_test_struct(vec![
1975            Some((Some(true), Some(42))),
1976            Some((Some(false), Some(28))),
1977            Some((Some(false), Some(19))),
1978            Some((Some(true), Some(31))),
1979            None,
1980        ]);
1981
1982        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1983        let actual = take(&array, &index, None).unwrap();
1984        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1985        assert_eq!(index.len(), actual.len());
1986        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
1987
1988        let expected = create_test_struct(vec![
1989            None,
1990            Some((Some(true), Some(31))),
1991            Some((Some(false), Some(28))),
1992            None,
1993            Some((Some(true), Some(42))),
1994            None,
1995        ]);
1996
1997        assert_eq!(&expected, actual);
1998    }
1999
2000    #[test]
2001    fn test_take_out_of_bounds() {
2002        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2003        let take_opt = TakeOptions { check_bounds: true };
2004
2005        // int64
2006        let result = test_take_primitive_arrays::<Int64Type>(
2007            vec![Some(0), None, Some(2), Some(3), None],
2008            &index,
2009            Some(take_opt),
2010            vec![None],
2011        );
2012        assert!(result.is_err());
2013    }
2014
2015    #[test]
2016    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2017    fn test_take_out_of_bounds_panic() {
2018        let index = UInt32Array::from(vec![Some(1000)]);
2019
2020        test_take_primitive_arrays::<Int64Type>(
2021            vec![Some(0), Some(1), Some(2), Some(3)],
2022            &index,
2023            None,
2024            vec![None],
2025        )
2026        .unwrap();
2027    }
2028
2029    #[test]
2030    fn test_null_array_smaller_than_indices() {
2031        let values = NullArray::new(2);
2032        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2033
2034        let result = take(&values, &indices, None).unwrap();
2035        let expected: ArrayRef = Arc::new(NullArray::new(3));
2036        assert_eq!(&result, &expected);
2037    }
2038
2039    #[test]
2040    fn test_null_array_larger_than_indices() {
2041        let values = NullArray::new(5);
2042        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2043
2044        let result = take(&values, &indices, None).unwrap();
2045        let expected: ArrayRef = Arc::new(NullArray::new(3));
2046        assert_eq!(&result, &expected);
2047    }
2048
2049    #[test]
2050    fn test_null_array_indices_out_of_bounds() {
2051        let values = NullArray::new(5);
2052        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2053
2054        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2055        assert_eq!(
2056            result.unwrap_err().to_string(),
2057            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2058        );
2059    }
2060
2061    #[test]
2062    fn test_take_dict() {
2063        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2064
2065        dict_builder.append("foo").unwrap();
2066        dict_builder.append("bar").unwrap();
2067        dict_builder.append("").unwrap();
2068        dict_builder.append_null();
2069        dict_builder.append("foo").unwrap();
2070        dict_builder.append("bar").unwrap();
2071        dict_builder.append("bar").unwrap();
2072        dict_builder.append("foo").unwrap();
2073
2074        let array = dict_builder.finish();
2075        let dict_values = array.values().clone();
2076        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2077
2078        let indices = UInt32Array::from(vec![
2079            Some(0), // first "foo"
2080            Some(7), // last "foo"
2081            None,    // null index should return null
2082            Some(5), // second "bar"
2083            Some(6), // another "bar"
2084            Some(2), // empty string
2085            Some(3), // input is null at this index
2086        ]);
2087
2088        let result = take(&array, &indices, None).unwrap();
2089        let result = result
2090            .as_any()
2091            .downcast_ref::<DictionaryArray<Int16Type>>()
2092            .unwrap();
2093
2094        let result_values: StringArray = result.values().to_data().into();
2095
2096        // dictionary values should stay the same
2097        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2098        assert_eq!(&expected_values, dict_values);
2099        assert_eq!(&expected_values, &result_values);
2100
2101        let expected_keys = Int16Array::from(vec![
2102            Some(0),
2103            Some(0),
2104            None,
2105            Some(1),
2106            Some(1),
2107            Some(2),
2108            None,
2109        ]);
2110        assert_eq!(result.keys(), &expected_keys);
2111    }
2112
2113    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2114    where
2115        S: OffsetSizeTrait + 'static,
2116        T: ArrowPrimitiveType,
2117        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2118    {
2119        GenericListArray::from_iter_primitive::<T, _, _>(
2120            data.iter()
2121                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2122        )
2123    }
2124
2125    #[test]
2126    fn test_take_value_index_from_list() {
2127        let list = build_generic_list::<i32, Int32Type>(vec![
2128            Some(vec![0, 1]),
2129            Some(vec![2, 3, 4]),
2130            Some(vec![5, 6, 7, 8, 9]),
2131        ]);
2132        let indices = UInt32Array::from(vec![2, 0]);
2133
2134        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2135
2136        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2137        assert_eq!(offsets, vec![0, 5, 7]);
2138        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2139    }
2140
2141    #[test]
2142    fn test_take_value_index_from_large_list() {
2143        let list = build_generic_list::<i64, Int32Type>(vec![
2144            Some(vec![0, 1]),
2145            Some(vec![2, 3, 4]),
2146            Some(vec![5, 6, 7, 8, 9]),
2147        ]);
2148        let indices = UInt32Array::from(vec![2, 0]);
2149
2150        let (indexed, offsets, null_buf) =
2151            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2152
2153        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2154        assert_eq!(offsets, vec![0, 5, 7]);
2155        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2156    }
2157
2158    #[test]
2159    fn test_take_runs() {
2160        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2161
2162        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2163        builder.extend(logical_array.into_iter().map(Some));
2164        let run_array = builder.finish();
2165
2166        let take_indices: PrimitiveArray<Int32Type> =
2167            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2168
2169        let take_out = take_run(&run_array, &take_indices).unwrap();
2170
2171        assert_eq!(take_out.len(), 7);
2172        assert_eq!(take_out.run_ends().len(), 7);
2173        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2174
2175        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2176        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2177    }
2178
2179    #[test]
2180    fn test_take_value_index_from_fixed_list() {
2181        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2182            vec![
2183                Some(vec![Some(1), Some(2), None]),
2184                Some(vec![Some(4), None, Some(6)]),
2185                None,
2186                Some(vec![None, Some(8), Some(9)]),
2187            ],
2188            3,
2189        );
2190
2191        let indices = UInt32Array::from(vec![2, 1, 0]);
2192        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2193
2194        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2195
2196        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2197        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2198
2199        assert_eq!(
2200            indexed,
2201            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2202        );
2203    }
2204
2205    #[test]
2206    fn test_take_null_indices() {
2207        // Build indices with values that are out of bounds, but masked by null mask
2208        let indices = Int32Array::new(
2209            vec![1, 2, 400, 400].into(),
2210            Some(NullBuffer::from(vec![true, true, false, false])),
2211        );
2212        let values = Int32Array::from(vec![1, 23, 4, 5]);
2213        let r = take(&values, &indices, None).unwrap();
2214        let values = r
2215            .as_primitive::<Int32Type>()
2216            .into_iter()
2217            .collect::<Vec<_>>();
2218        assert_eq!(&values, &[Some(23), Some(4), None, None])
2219    }
2220
2221    #[test]
2222    fn test_take_fixed_size_list_null_indices() {
2223        let indices = Int32Array::from_iter([Some(0), None]);
2224        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2225        let arr_field = Arc::new(Field::new("item", values.data_type().clone(), true));
2226        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2227
2228        let r = take(&values, &indices, None).unwrap();
2229        let values = r
2230            .as_fixed_size_list()
2231            .values()
2232            .as_primitive::<Int32Type>()
2233            .into_iter()
2234            .collect::<Vec<_>>();
2235        assert_eq!(values, &[Some(0), Some(1), None, None])
2236    }
2237
2238    #[test]
2239    fn test_take_bytes_null_indices() {
2240        let indices = Int32Array::new(
2241            vec![0, 1, 400, 400].into(),
2242            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2243        );
2244        let values = StringArray::from(vec![Some("foo"), None]);
2245        let r = take(&values, &indices, None).unwrap();
2246        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2247        assert_eq!(&values, &[Some("foo"), None, None, None])
2248    }
2249
2250    #[test]
2251    fn test_take_union_sparse() {
2252        let structs = create_test_struct(vec![
2253            Some((Some(true), Some(42))),
2254            Some((Some(false), Some(28))),
2255            Some((Some(false), Some(19))),
2256            Some((Some(true), Some(31))),
2257            None,
2258        ]);
2259        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2260        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2261
2262        let union_fields = [
2263            (
2264                0,
2265                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2266            ),
2267            (
2268                1,
2269                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2270            ),
2271        ]
2272        .into_iter()
2273        .collect();
2274        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2275        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2276
2277        let indices = vec![0, 3, 1, 0, 2, 4];
2278        let index = UInt32Array::from(indices.clone());
2279        let actual = take(&array, &index, None).unwrap();
2280        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2281        let strings = actual.child(1);
2282        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2283
2284        let actual = strings.iter().collect::<Vec<_>>();
2285        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2286        assert_eq!(expected, actual);
2287    }
2288
2289    #[test]
2290    fn test_take_union_dense() {
2291        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2292        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2293        let ints = vec![10, 20, 30, 40];
2294        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2295
2296        let indices = vec![0, 3, 1, 0, 2, 4];
2297
2298        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2299        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2300        let taken_ints = vec![10, 20, 10, 30];
2301        let taken_strings = vec![Some("a"), None];
2302
2303        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2304        let offsets = <ScalarBuffer<i32>>::from(offsets);
2305        let ints = UInt32Array::from(ints);
2306        let strings = StringArray::from(strings);
2307
2308        let union_fields = [
2309            (
2310                0,
2311                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2312            ),
2313            (
2314                1,
2315                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2316            ),
2317        ]
2318        .into_iter()
2319        .collect();
2320
2321        let array = UnionArray::try_new(
2322            union_fields,
2323            type_ids,
2324            Some(offsets),
2325            vec![Arc::new(ints), Arc::new(strings)],
2326        )
2327        .unwrap();
2328
2329        let index = UInt32Array::from(indices);
2330
2331        let actual = take(&array, &index, None).unwrap();
2332        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2333
2334        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2335        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2336        assert_eq!(
2337            UInt32Array::from(actual.child(0).to_data()),
2338            UInt32Array::from(taken_ints)
2339        );
2340        assert_eq!(
2341            StringArray::from(actual.child(1).to_data()),
2342            StringArray::from(taken_strings)
2343        );
2344    }
2345
2346    #[test]
2347    fn test_take_union_dense_using_builder() {
2348        let mut builder = UnionBuilder::new_dense();
2349
2350        builder.append::<Int32Type>("a", 1).unwrap();
2351        builder.append::<Float64Type>("b", 3.0).unwrap();
2352        builder.append::<Int32Type>("a", 4).unwrap();
2353        builder.append::<Int32Type>("a", 5).unwrap();
2354        builder.append::<Float64Type>("b", 2.0).unwrap();
2355
2356        let union = builder.build().unwrap();
2357
2358        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2359
2360        let mut builder = UnionBuilder::new_dense();
2361
2362        builder.append::<Int32Type>("a", 4).unwrap();
2363        builder.append::<Int32Type>("a", 1).unwrap();
2364        builder.append::<Float64Type>("b", 3.0).unwrap();
2365        builder.append::<Int32Type>("a", 4).unwrap();
2366
2367        let taken = builder.build().unwrap();
2368
2369        assert_eq!(
2370            taken.to_data(),
2371            take(&union, &indices, None).unwrap().to_data()
2372        );
2373    }
2374
2375    #[test]
2376    fn test_take_union_dense_all_match_issue_6206() {
2377        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2378        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2379
2380        let array = UnionArray::try_new(
2381            fields,
2382            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2383            Some(ScalarBuffer::from_iter(0_i32..5)),
2384            vec![ints],
2385        )
2386        .unwrap();
2387
2388        let indicies = Int64Array::from(vec![0, 2, 4]);
2389        let array = take(&array, &indicies, None).unwrap();
2390        assert_eq!(array.len(), 3);
2391    }
2392}