arrow_select/
zip.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy`
//! are taken, where the mask evaluates `false` values of `falsy` are taken.
19
20use crate::filter::SlicesIterator;
21use arrow_array::*;
22use arrow_data::transform::MutableArrayData;
23use arrow_schema::ArrowError;
24
25/// Zip two arrays by some boolean mask. Where the mask evaluates `true` values of `truthy`
26/// are taken, where the mask evaluates `false` values of `falsy` are taken.
27///
28/// # Arguments
29/// * `mask` - Boolean values used to determine from which array to take the values.
30/// * `truthy` - Values of this array are taken if mask evaluates `true`
31/// * `falsy` - Values of this array are taken if mask evaluates `false`
32pub fn zip(
33    mask: &BooleanArray,
34    truthy: &dyn Datum,
35    falsy: &dyn Datum,
36) -> Result<ArrayRef, ArrowError> {
37    let (truthy, truthy_is_scalar) = truthy.get();
38    let (falsy, falsy_is_scalar) = falsy.get();
39
40    if truthy.data_type() != falsy.data_type() {
41        return Err(ArrowError::InvalidArgumentError(
42            "arguments need to have the same data type".into(),
43        ));
44    }
45
46    if truthy_is_scalar && truthy.len() != 1 {
47        return Err(ArrowError::InvalidArgumentError(
48            "scalar arrays must have 1 element".into(),
49        ));
50    }
51    if !truthy_is_scalar && truthy.len() != mask.len() {
52        return Err(ArrowError::InvalidArgumentError(
53            "all arrays should have the same length".into(),
54        ));
55    }
56    if falsy_is_scalar && falsy.len() != 1 {
57        return Err(ArrowError::InvalidArgumentError(
58            "scalar arrays must have 1 element".into(),
59        ));
60    }
61    if !falsy_is_scalar && falsy.len() != mask.len() {
62        return Err(ArrowError::InvalidArgumentError(
63            "all arrays should have the same length".into(),
64        ));
65    }
66
67    let falsy = falsy.to_data();
68    let truthy = truthy.to_data();
69
70    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());
71
72    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
73    // fill with falsy values
74
75    // keep track of how much is filled
76    let mut filled = 0;
77
78    SlicesIterator::new(mask).for_each(|(start, end)| {
79        // the gap needs to be filled with falsy values
80        if start > filled {
81            if falsy_is_scalar {
82                for _ in filled..start {
83                    // Copy the first item from the 'falsy' array into the output buffer.
84                    mutable.extend(1, 0, 1);
85                }
86            } else {
87                mutable.extend(1, filled, start);
88            }
89        }
90        // fill with truthy values
91        if truthy_is_scalar {
92            for _ in start..end {
93                // Copy the first item from the 'truthy' array into the output buffer.
94                mutable.extend(0, 0, 1);
95            }
96        } else {
97            mutable.extend(0, start, end);
98        }
99        filled = end;
100    });
101    // the remaining part is falsy
102    if filled < mask.len() {
103        if falsy_is_scalar {
104            for _ in filled..mask.len() {
105                // Copy the first item from the 'falsy' array into the output buffer.
106                mutable.extend(1, 0, 1);
107            }
108        } else {
109            mutable.extend(1, filled, mask.len());
110        }
111    }
112
113    let data = mutable.freeze();
114    Ok(make_array(data))
115}
116
#[cfg(test)]
mod test {
    use super::*;

    /// `[5, null, 7, null, 1]`: the array input shared by most cases.
    fn primary() -> Int32Array {
        Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)])
    }

    /// `[null, 3, 6, 7, 3]`: the second array input.
    fn secondary() -> Int32Array {
        Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)])
    }

    /// Mask `[true, true, false, false, true]`.
    fn mask_ttfft() -> BooleanArray {
        BooleanArray::from(vec![true, true, false, false, true])
    }

    /// Mask `[false, false, true, true, false]` (the inverse of `mask_ttfft`).
    fn mask_fftff() -> BooleanArray {
        BooleanArray::from(vec![false, false, true, true, false])
    }

    /// A one-element `Int32Array` scalar holding `value`.
    fn scalar(value: i32) -> Scalar<Int32Array> {
        Scalar::new(Int32Array::from_value(value, 1))
    }

    /// A one-element `Int32Array` scalar holding null.
    fn null_scalar() -> Scalar<Int32Array> {
        Scalar::new(Int32Array::new_null(1))
    }

    /// Zips the inputs and asserts the output is an `Int32Array` equal to `expected`.
    fn check(
        mask: &BooleanArray,
        truthy: &dyn Datum,
        falsy: &dyn Datum,
        expected: Vec<Option<i32>>,
    ) {
        let zipped = zip(mask, truthy, falsy).unwrap();
        let got = zipped.as_any().downcast_ref::<Int32Array>().unwrap();
        let want = Int32Array::from(expected);
        assert_eq!(got, &want);
    }

    #[test]
    fn test_zip_kernel_one() {
        check(
            &mask_ttfft(),
            &primary(),
            &secondary(),
            vec![Some(5), None, Some(6), Some(7), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_two() {
        check(
            &mask_fftff(),
            &primary(),
            &secondary(),
            vec![None, Some(3), Some(7), None, Some(3)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_1() {
        check(
            &mask_ttfft(),
            &primary(),
            &scalar(42),
            vec![Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_2() {
        check(
            &mask_fftff(),
            &primary(),
            &scalar(42),
            vec![Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_1() {
        check(
            &mask_ttfft(),
            &scalar(42),
            &primary(),
            vec![Some(42), Some(42), Some(7), None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_2() {
        check(
            &mask_fftff(),
            &scalar(42),
            &primary(),
            vec![Some(5), None, Some(42), Some(42), Some(1)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_both() {
        check(
            &mask_ttfft(),
            &scalar(42),
            &scalar(123),
            vec![Some(42), Some(42), Some(123), Some(123), Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_1() {
        check(
            &mask_ttfft(),
            &scalar(42),
            &null_scalar(),
            vec![Some(42), Some(42), None, None, Some(42)],
        );
    }

    #[test]
    fn test_zip_kernel_scalar_none_2() {
        check(
            &mask_fftff(),
            &scalar(42),
            &null_scalar(),
            vec![None, None, Some(42), Some(42), None],
        );
    }
}