1use arrow_array::builder::BufferBuilder;
21use arrow_array::types::ArrowDictionaryKeyType;
22use arrow_array::*;
23use arrow_buffer::buffer::NullBuffer;
24use arrow_buffer::ArrowNativeType;
25use arrow_buffer::{Buffer, MutableBuffer};
26use arrow_data::ArrayData;
27use arrow_schema::ArrowError;
28use std::sync::Arc;
29
30pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
32where
33 I: ArrowPrimitiveType,
34 O: ArrowPrimitiveType,
35 F: Fn(I::Native) -> O::Native,
36{
37 array.unary(op)
38}
39
40pub fn unary_mut<I, F>(
42 array: PrimitiveArray<I>,
43 op: F,
44) -> Result<PrimitiveArray<I>, PrimitiveArray<I>>
45where
46 I: ArrowPrimitiveType,
47 F: Fn(I::Native) -> I::Native,
48{
49 array.unary_mut(op)
50}
51
52pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>, ArrowError>
54where
55 I: ArrowPrimitiveType,
56 O: ArrowPrimitiveType,
57 F: Fn(I::Native) -> Result<O::Native, ArrowError>,
58{
59 array.try_unary(op)
60}
61
62pub fn try_unary_mut<I, F>(
64 array: PrimitiveArray<I>,
65 op: F,
66) -> Result<Result<PrimitiveArray<I>, ArrowError>, PrimitiveArray<I>>
67where
68 I: ArrowPrimitiveType,
69 F: Fn(I::Native) -> Result<I::Native, ArrowError>,
70{
71 array.try_unary_mut(op)
72}
73
74fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef, ArrowError>
76where
77 K: ArrowDictionaryKeyType + ArrowNumericType,
78 T: ArrowPrimitiveType,
79 F: Fn(T::Native) -> T::Native,
80{
81 let dict_values = array.values().as_any().downcast_ref().unwrap();
82 let values = unary::<T, F, T>(dict_values, op);
83 Ok(Arc::new(array.with_values(Arc::new(values))))
84}
85
86fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef, ArrowError>
88where
89 K: ArrowDictionaryKeyType + ArrowNumericType,
90 T: ArrowPrimitiveType,
91 F: Fn(T::Native) -> Result<T::Native, ArrowError>,
92{
93 if !PrimitiveArray::<T>::is_compatible(&array.value_type()) {
94 return Err(ArrowError::CastError(format!(
95 "Cannot perform the unary operation of type {} on dictionary array of value type {}",
96 T::DATA_TYPE,
97 array.value_type()
98 )));
99 }
100
101 let dict_values = array.values().as_any().downcast_ref().unwrap();
102 let values = try_unary::<T, F, T>(dict_values, op)?;
103 Ok(Arc::new(array.with_values(Arc::new(values))))
104}
105
106#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
108pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
109where
110 T: ArrowPrimitiveType,
111 F: Fn(T::Native) -> T::Native,
112{
113 downcast_dictionary_array! {
114 array => unary_dict::<_, F, T>(array, op),
115 t => {
116 if PrimitiveArray::<T>::is_compatible(t) {
117 Ok(Arc::new(unary::<T, F, T>(
118 array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
119 op,
120 )))
121 } else {
122 Err(ArrowError::NotYetImplemented(format!(
123 "Cannot perform unary operation of type {} on array of type {}",
124 T::DATA_TYPE,
125 t
126 )))
127 }
128 }
129 }
130}
131
132#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
134pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
135where
136 T: ArrowPrimitiveType,
137 F: Fn(T::Native) -> Result<T::Native, ArrowError>,
138{
139 downcast_dictionary_array! {
140 array => if array.values().data_type() == &T::DATA_TYPE {
141 try_unary_dict::<_, F, T>(array, op)
142 } else {
143 Err(ArrowError::NotYetImplemented(format!(
144 "Cannot perform unary operation on dictionary array of type {}",
145 array.data_type()
146 )))
147 },
148 t => {
149 if PrimitiveArray::<T>::is_compatible(t) {
150 Ok(Arc::new(try_unary::<T, F, T>(
151 array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
152 op,
153 )?))
154 } else {
155 Err(ArrowError::NotYetImplemented(format!(
156 "Cannot perform unary operation of type {} on array of type {}",
157 T::DATA_TYPE,
158 t
159 )))
160 }
161 }
162 }
163}
164
165pub fn binary<A, B, F, O>(
198 a: &PrimitiveArray<A>,
199 b: &PrimitiveArray<B>,
200 op: F,
201) -> Result<PrimitiveArray<O>, ArrowError>
202where
203 A: ArrowPrimitiveType,
204 B: ArrowPrimitiveType,
205 O: ArrowPrimitiveType,
206 F: Fn(A::Native, B::Native) -> O::Native,
207{
208 if a.len() != b.len() {
209 return Err(ArrowError::ComputeError(
210 "Cannot perform binary operation on arrays of different length".to_string(),
211 ));
212 }
213
214 if a.is_empty() {
215 return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
216 }
217
218 let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
219
220 let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r));
221 let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
227 Ok(PrimitiveArray::new(buffer.into(), nulls))
228}
229
230pub fn binary_mut<T, U, F>(
295 a: PrimitiveArray<T>,
296 b: &PrimitiveArray<U>,
297 op: F,
298) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
299where
300 T: ArrowPrimitiveType,
301 U: ArrowPrimitiveType,
302 F: Fn(T::Native, U::Native) -> T::Native,
303{
304 if a.len() != b.len() {
305 return Ok(Err(ArrowError::ComputeError(
306 "Cannot perform binary operation on arrays of different length".to_string(),
307 )));
308 }
309
310 if a.is_empty() {
311 return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
312 &T::DATA_TYPE,
313 ))));
314 }
315
316 let mut builder = a.into_builder()?;
317
318 builder
319 .values_slice_mut()
320 .iter_mut()
321 .zip(b.values())
322 .for_each(|(l, r)| *l = op(*l, *r));
323
324 let array = builder.finish();
325
326 let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());
328
329 let array_builder = array.into_data().into_builder().nulls(nulls);
330
331 let array_data = unsafe { array_builder.build_unchecked() };
332 Ok(Ok(PrimitiveArray::<T>::from(array_data)))
333}
334
335pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
348 a: A,
349 b: B,
350 op: F,
351) -> Result<PrimitiveArray<O>, ArrowError>
352where
353 O: ArrowPrimitiveType,
354 F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
355{
356 if a.len() != b.len() {
357 return Err(ArrowError::ComputeError(
358 "Cannot perform a binary operation on arrays of different length".to_string(),
359 ));
360 }
361 if a.is_empty() {
362 return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
363 }
364 let len = a.len();
365
366 if a.null_count() == 0 && b.null_count() == 0 {
367 try_binary_no_nulls(len, a, b, op)
368 } else {
369 let nulls =
370 NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
371
372 let mut buffer = BufferBuilder::<O::Native>::new(len);
373 buffer.append_n_zeroed(len);
374 let slice = buffer.as_slice_mut();
375
376 nulls.try_for_each_valid_idx(|idx| {
377 unsafe {
378 *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))?
379 };
380 Ok::<_, ArrowError>(())
381 })?;
382
383 let values = buffer.finish().into();
384 Ok(PrimitiveArray::new(values, Some(nulls)))
385 }
386}
387
388pub fn try_binary_mut<T, F>(
399 a: PrimitiveArray<T>,
400 b: &PrimitiveArray<T>,
401 op: F,
402) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
403where
404 T: ArrowPrimitiveType,
405 F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
406{
407 if a.len() != b.len() {
408 return Ok(Err(ArrowError::ComputeError(
409 "Cannot perform binary operation on arrays of different length".to_string(),
410 )));
411 }
412 let len = a.len();
413
414 if a.is_empty() {
415 return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
416 &T::DATA_TYPE,
417 ))));
418 }
419
420 if a.null_count() == 0 && b.null_count() == 0 {
421 try_binary_no_nulls_mut(len, a, b, op)
422 } else {
423 let nulls =
424 create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
425 .unwrap();
426
427 let mut builder = a.into_builder()?;
428
429 let slice = builder.values_slice_mut();
430
431 let r = nulls.try_for_each_valid_idx(|idx| {
432 unsafe {
433 *slice.get_unchecked_mut(idx) =
434 op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
435 };
436 Ok::<_, ArrowError>(())
437 });
438 if let Err(err) = r {
439 return Ok(Err(err));
440 }
441 let array_builder = builder.finish().into_data().into_builder();
442 let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() };
443 Ok(Ok(PrimitiveArray::<T>::from(array_data)))
444 }
445}
446
447fn create_union_null_buffer(
453 lhs: Option<&NullBuffer>,
454 rhs: Option<&NullBuffer>,
455) -> Option<NullBuffer> {
456 match (lhs, rhs) {
457 (Some(lhs), Some(rhs)) => Some(NullBuffer::new(lhs.inner() & rhs.inner())),
458 (Some(n), None) | (None, Some(n)) => Some(NullBuffer::new(n.inner() & n.inner())),
459 (None, None) => None,
460 }
461}
462
463#[inline(never)]
465fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
466 len: usize,
467 a: A,
468 b: B,
469 op: F,
470) -> Result<PrimitiveArray<O>, ArrowError>
471where
472 O: ArrowPrimitiveType,
473 F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
474{
475 let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
476 for idx in 0..len {
477 unsafe {
478 buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
479 };
480 }
481 Ok(PrimitiveArray::new(buffer.into(), None))
482}
483
484#[inline(never)]
486fn try_binary_no_nulls_mut<T, F>(
487 len: usize,
488 a: PrimitiveArray<T>,
489 b: &PrimitiveArray<T>,
490 op: F,
491) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
492where
493 T: ArrowPrimitiveType,
494 F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
495{
496 let mut builder = a.into_builder()?;
497 let slice = builder.values_slice_mut();
498
499 for idx in 0..len {
500 unsafe {
501 match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
502 Ok(value) => *slice.get_unchecked_mut(idx) = value,
503 Err(err) => return Ok(Err(err)),
504 };
505 };
506 }
507 Ok(Ok(builder.finish()))
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::*;
    use arrow_array::types::*;

    /// Builds an `Int8`-keyed dictionary of `Int32` values, appending a null
    /// slot for every `None` in `items`.
    fn int32_dictionary(items: &[Option<i32>]) -> DictionaryArray<Int8Type> {
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for item in items {
            match item {
                Some(value) => {
                    builder.append(*value).unwrap();
                }
                None => builder.append_null(),
            }
        }
        builder.finish()
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated `unary_dyn`
    fn test_unary_f64_slice() {
        let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        // Work on a slice so non-zero offsets are exercised.
        let sliced = input.slice(1, 4);

        let rounded = unary(&sliced, |n| n.round());
        assert_eq!(
            rounded,
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
        );

        let shifted = unary_dyn::<_, Float64Type>(&sliced, |n| n + 1.0).unwrap();
        let shifted = shifted.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(
            shifted,
            &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
        );
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated `unary_dyn`
    fn test_unary_dict_and_unary_dyn() {
        let dictionary_array =
            int32_dictionary(&[Some(5), Some(6), Some(7), Some(8), None, Some(9)]);
        let expected = int32_dictionary(&[Some(6), Some(7), Some(8), Some(9), None, Some(10)]);

        let assert_matches_expected = |result: ArrayRef| {
            let typed = result
                .as_any()
                .downcast_ref::<DictionaryArray<Int8Type>>()
                .unwrap();
            assert_eq!(typed, &expected);
        };

        assert_matches_expected(
            unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap(),
        );
        assert_matches_expected(unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap());
    }

    #[test]
    fn test_binary_mut() {
        let lhs = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let rhs = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let sum = binary_mut(lhs, &rhs, |l, r| l + r).unwrap().unwrap();

        assert_eq!(
            sum,
            Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)])
        );
    }

    #[test]
    fn test_binary_mut_null_buffer() {
        // Building `b` from options or from values plus an explicit all-valid
        // null buffer must give identical results.
        let from_options = {
            let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
            let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
            binary_mut(a, &b, |x, y| x + y).unwrap()
        };
        let with_explicit_mask = {
            let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
            let b = Int32Array::new(
                vec![10, 11, 12, 13, 14].into(),
                Some(vec![true, true, true, true, true].into()),
            );
            binary_mut(a, &b, |x, y| x + y).unwrap()
        };

        assert_eq!(from_options.unwrap(), with_explicit_mask.unwrap());
    }

    #[test]
    fn test_try_binary_mut() {
        // Nulls in `b` propagate to the result.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        assert_eq!(
            c,
            Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)])
        );

        // Neither side has nulls.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        assert_eq!(c, Int32Array::from(vec![16, 16, 12, 12, 6]));

        // An error from `op` at a valid slot surfaces as `Ok(Err(_))`.
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let outcome = try_binary_mut(a, &b, |l, r| {
            if l == 1 {
                Err(ArrowError::InvalidArgumentError("got error".to_string()))
            } else {
                Ok(l + r)
            }
        });
        let _ = outcome.unwrap().expect_err("should got error");
    }

    #[test]
    fn test_try_binary_mut_null_buffer() {
        // Building `b` from options or from values plus an explicit all-valid
        // null buffer must give identical results.
        let from_options = {
            let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
            let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);
            try_binary_mut(a, &b, |x, y| Ok(x + y)).unwrap()
        };
        let with_explicit_mask = {
            let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
            let b = Int32Array::new(
                vec![10, 11, 12, 13, 14].into(),
                Some(vec![true, true, true, true, true].into()),
            );
            try_binary_mut(a, &b, |x, y| Ok(x + y)).unwrap()
        };

        assert_eq!(from_options.unwrap(), with_explicit_mask.unwrap());
    }

    #[test]
    fn test_unary_dict_mut() {
        let values = Int32Array::from(vec![Some(10), Some(20), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let dictionary = DictionaryArray::new(keys, Arc::new(values));

        let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap();
        let typed = updated.downcast_dict::<Int32Array>().unwrap();
        for (idx, want) in [(0, 11), (1, 11), (2, 21)] {
            assert_eq!(typed.value(idx), want);
        }

        // The null dictionary value stays null after the update.
        assert!(updated.values().is_null(2));
    }
}
668}