1use crate::cast::*;
19
20pub(crate) trait DecimalCast: Sized {
23 fn to_i128(self) -> Option<i128>;
24
25 fn to_i256(self) -> Option<i256>;
26
27 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
28}
29
30impl DecimalCast for i128 {
31 fn to_i128(self) -> Option<i128> {
32 Some(self)
33 }
34
35 fn to_i256(self) -> Option<i256> {
36 Some(i256::from_i128(self))
37 }
38
39 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
40 n.to_i128()
41 }
42}
43
44impl DecimalCast for i256 {
45 fn to_i128(self) -> Option<i128> {
46 self.to_i128()
47 }
48
49 fn to_i256(self) -> Option<i256> {
50 Some(self)
51 }
52
53 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
54 n.to_i256()
55 }
56}
57
58pub(crate) fn cast_decimal_to_decimal_error<I, O>(
59 output_precision: u8,
60 output_scale: i8,
61) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
62where
63 I: DecimalType,
64 O: DecimalType,
65 I::Native: DecimalCast + ArrowNativeTypeOp,
66 O::Native: DecimalCast + ArrowNativeTypeOp,
67{
68 move |x: I::Native| {
69 ArrowError::CastError(format!(
70 "Cannot cast to {}({}, {}). Overflowing on {:?}",
71 O::PREFIX,
72 output_precision,
73 output_scale,
74 x
75 ))
76 }
77}
78
79pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
80 array: &PrimitiveArray<I>,
81 input_scale: i8,
82 output_precision: u8,
83 output_scale: i8,
84 cast_options: &CastOptions,
85) -> Result<PrimitiveArray<O>, ArrowError>
86where
87 I: DecimalType,
88 O: DecimalType,
89 I::Native: DecimalCast + ArrowNativeTypeOp,
90 O::Native: DecimalCast + ArrowNativeTypeOp,
91{
92 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
93 let div = I::Native::from_decimal(10_i128)
94 .unwrap()
95 .pow_checked((input_scale - output_scale) as u32)?;
96
97 let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
98 let half_neg = half.neg_wrapping();
99
100 let f = |x: I::Native| {
101 let d = x.div_wrapping(div);
103 let r = x.mod_wrapping(div);
104
105 let adjusted = match x >= I::Native::ZERO {
107 true if r >= half => d.add_wrapping(I::Native::ONE),
108 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
109 _ => d,
110 };
111 O::Native::from_decimal(adjusted)
112 };
113
114 Ok(match cast_options.safe {
115 true => array.unary_opt(f),
116 false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
117 })
118}
119
120pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
121 array: &PrimitiveArray<I>,
122 input_scale: i8,
123 output_precision: u8,
124 output_scale: i8,
125 cast_options: &CastOptions,
126) -> Result<PrimitiveArray<O>, ArrowError>
127where
128 I: DecimalType,
129 O: DecimalType,
130 I::Native: DecimalCast + ArrowNativeTypeOp,
131 O::Native: DecimalCast + ArrowNativeTypeOp,
132{
133 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
134 let mul = O::Native::from_decimal(10_i128)
135 .unwrap()
136 .pow_checked((output_scale - input_scale) as u32)?;
137
138 let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
139
140 Ok(match cast_options.safe {
141 true => array.unary_opt(f),
142 false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
143 })
144}
145
146pub(crate) fn cast_decimal_to_decimal_same_type<T>(
148 array: &PrimitiveArray<T>,
149 input_scale: i8,
150 output_precision: u8,
151 output_scale: i8,
152 cast_options: &CastOptions,
153) -> Result<ArrayRef, ArrowError>
154where
155 T: DecimalType,
156 T::Native: DecimalCast + ArrowNativeTypeOp,
157{
158 let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
159 Ordering::Equal => {
160 array.clone()
162 }
163 Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
164 array,
165 input_scale,
166 output_precision,
167 output_scale,
168 cast_options,
169 )?,
170 Ordering::Less => {
171 convert_to_bigger_or_equal_scale_decimal::<T, T>(
173 array,
174 input_scale,
175 output_precision,
176 output_scale,
177 cast_options,
178 )?
179 }
180 };
181
182 Ok(Arc::new(array.with_precision_and_scale(
183 output_precision,
184 output_scale,
185 )?))
186}
187
188pub(crate) fn cast_decimal_to_decimal<I, O>(
190 array: &PrimitiveArray<I>,
191 input_scale: i8,
192 output_precision: u8,
193 output_scale: i8,
194 cast_options: &CastOptions,
195) -> Result<ArrayRef, ArrowError>
196where
197 I: DecimalType,
198 O: DecimalType,
199 I::Native: DecimalCast + ArrowNativeTypeOp,
200 O::Native: DecimalCast + ArrowNativeTypeOp,
201{
202 let array: PrimitiveArray<O> = if input_scale > output_scale {
203 convert_to_smaller_scale_decimal::<I, O>(
204 array,
205 input_scale,
206 output_precision,
207 output_scale,
208 cast_options,
209 )?
210 } else {
211 convert_to_bigger_or_equal_scale_decimal::<I, O>(
212 array,
213 input_scale,
214 output_precision,
215 output_scale,
216 cast_options,
217 )?
218 };
219
220 Ok(Arc::new(array.with_precision_and_scale(
221 output_precision,
222 output_scale,
223 )?))
224}
225
226pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
229 value_str: &str,
230 scale: usize,
231) -> Result<T::Native, ArrowError>
232where
233 T::Native: DecimalCast + ArrowNativeTypeOp,
234{
235 let value_str = value_str.trim();
236 let parts: Vec<&str> = value_str.split('.').collect();
237 if parts.len() > 2 {
238 return Err(ArrowError::InvalidArgumentError(format!(
239 "Invalid decimal format: {value_str:?}"
240 )));
241 }
242
243 let (negative, first_part) = if parts[0].is_empty() {
244 (false, parts[0])
245 } else {
246 match parts[0].as_bytes()[0] {
247 b'-' => (true, &parts[0][1..]),
248 b'+' => (false, &parts[0][1..]),
249 _ => (false, parts[0]),
250 }
251 };
252
253 let integers = first_part;
254 let decimals = if parts.len() == 2 { parts[1] } else { "" };
255
256 if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
257 return Err(ArrowError::InvalidArgumentError(format!(
258 "Invalid decimal format: {value_str:?}"
259 )));
260 }
261
262 if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
263 return Err(ArrowError::InvalidArgumentError(format!(
264 "Invalid decimal format: {value_str:?}"
265 )));
266 }
267
268 let mut number_decimals = if decimals.len() > scale {
270 let decimal_number = i256::from_string(decimals).ok_or_else(|| {
271 ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
272 })?;
273
274 let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
275
276 let half = div.div_wrapping(i256::from_i128(2));
277 let half_neg = half.neg_wrapping();
278
279 let d = decimal_number.div_wrapping(div);
280 let r = decimal_number.mod_wrapping(div);
281
282 let adjusted = match decimal_number >= i256::ZERO {
284 true if r >= half => d.add_wrapping(i256::ONE),
285 false if r <= half_neg => d.sub_wrapping(i256::ONE),
286 _ => d,
287 };
288
289 let integers = if !integers.is_empty() {
290 i256::from_string(integers)
291 .ok_or_else(|| {
292 ArrowError::InvalidArgumentError(format!(
293 "Cannot parse decimal format: {value_str}"
294 ))
295 })
296 .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
297 } else {
298 i256::ZERO
299 };
300
301 format!("{}", integers.add_wrapping(adjusted))
302 } else {
303 let padding = if scale > decimals.len() { scale } else { 0 };
304
305 let decimals = format!("{decimals:0<padding$}");
306 format!("{integers}{decimals}")
307 };
308
309 if negative {
310 number_decimals.insert(0, '-');
311 }
312
313 let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
314 ArrowError::InvalidArgumentError(format!(
315 "Cannot convert {} to {}: Overflow",
316 value_str,
317 T::PREFIX
318 ))
319 })?;
320
321 T::Native::from_decimal(value).ok_or_else(|| {
322 ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
323 })
324}
325
326pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
327 from: &'a S,
328 precision: u8,
329 scale: i8,
330 cast_options: &CastOptions,
331) -> Result<PrimitiveArray<T>, ArrowError>
332where
333 T: DecimalType,
334 T::Native: DecimalCast + ArrowNativeTypeOp,
335 &'a S: StringArrayType<'a>,
336{
337 if cast_options.safe {
338 let iter = from.iter().map(|v| {
339 v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
340 .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
341 });
342 Ok(unsafe {
347 PrimitiveArray::<T>::from_trusted_len_iter(iter)
348 .with_precision_and_scale(precision, scale)?
349 })
350 } else {
351 let vec = from
352 .iter()
353 .map(|v| {
354 v.map(|v| {
355 parse_string_to_decimal_native::<T>(v, scale as usize)
356 .map_err(|_| {
357 ArrowError::CastError(format!(
358 "Cannot cast string '{}' to value of {:?} type",
359 v,
360 T::DATA_TYPE,
361 ))
362 })
363 .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
364 })
365 .transpose()
366 })
367 .collect::<Result<Vec<_>, _>>()?;
368 Ok(unsafe {
373 PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
374 .with_precision_and_scale(precision, scale)?
375 })
376 }
377}
378
379pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
380 from: &GenericStringArray<Offset>,
381 precision: u8,
382 scale: i8,
383 cast_options: &CastOptions,
384) -> Result<PrimitiveArray<T>, ArrowError>
385where
386 T: DecimalType,
387 T::Native: DecimalCast + ArrowNativeTypeOp,
388{
389 generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
390 from,
391 precision,
392 scale,
393 cast_options,
394 )
395}
396
397pub(crate) fn string_view_to_decimal_cast<T>(
398 from: &StringViewArray,
399 precision: u8,
400 scale: i8,
401 cast_options: &CastOptions,
402) -> Result<PrimitiveArray<T>, ArrowError>
403where
404 T: DecimalType,
405 T::Native: DecimalCast + ArrowNativeTypeOp,
406{
407 generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
408}
409
410pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
412 from: &dyn Array,
413 precision: u8,
414 scale: i8,
415 cast_options: &CastOptions,
416) -> Result<ArrayRef, ArrowError>
417where
418 T: DecimalType,
419 T::Native: DecimalCast + ArrowNativeTypeOp,
420{
421 if scale < 0 {
422 return Err(ArrowError::InvalidArgumentError(format!(
423 "Cannot cast string to decimal with negative scale {scale}"
424 )));
425 }
426
427 if scale > T::MAX_SCALE {
428 return Err(ArrowError::InvalidArgumentError(format!(
429 "Cannot cast string to decimal greater than maximum scale {}",
430 T::MAX_SCALE
431 )));
432 }
433
434 let result = match from.data_type() {
435 DataType::Utf8View => string_view_to_decimal_cast::<T>(
436 from.as_any().downcast_ref::<StringViewArray>().unwrap(),
437 precision,
438 scale,
439 cast_options,
440 )?,
441 DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
442 from.as_any()
443 .downcast_ref::<GenericStringArray<Offset>>()
444 .unwrap(),
445 precision,
446 scale,
447 cast_options,
448 )?,
449 other => {
450 return Err(ArrowError::ComputeError(format!(
451 "Cannot cast {:?} to decimal",
452 other
453 )))
454 }
455 };
456
457 Ok(Arc::new(result))
458}
459
460pub(crate) fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
461 array: &PrimitiveArray<T>,
462 precision: u8,
463 scale: i8,
464 cast_options: &CastOptions,
465) -> Result<ArrayRef, ArrowError>
466where
467 <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
468{
469 let mul = 10_f64.powi(scale as i32);
470
471 if cast_options.safe {
472 array
473 .unary_opt::<_, Decimal128Type>(|v| {
474 (mul * v.as_())
475 .round()
476 .to_i128()
477 .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision))
478 })
479 .with_precision_and_scale(precision, scale)
480 .map(|a| Arc::new(a) as ArrayRef)
481 } else {
482 array
483 .try_unary::<_, Decimal128Type, _>(|v| {
484 (mul * v.as_())
485 .round()
486 .to_i128()
487 .ok_or_else(|| {
488 ArrowError::CastError(format!(
489 "Cannot cast to {}({}, {}). Overflowing on {:?}",
490 Decimal128Type::PREFIX,
491 precision,
492 scale,
493 v
494 ))
495 })
496 .and_then(|v| {
497 Decimal128Type::validate_decimal_precision(v, precision).map(|_| v)
498 })
499 })?
500 .with_precision_and_scale(precision, scale)
501 .map(|a| Arc::new(a) as ArrayRef)
502 }
503}
504
505pub(crate) fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
506 array: &PrimitiveArray<T>,
507 precision: u8,
508 scale: i8,
509 cast_options: &CastOptions,
510) -> Result<ArrayRef, ArrowError>
511where
512 <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
513{
514 let mul = 10_f64.powi(scale as i32);
515
516 if cast_options.safe {
517 array
518 .unary_opt::<_, Decimal256Type>(|v| {
519 i256::from_f64((v.as_() * mul).round())
520 .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision))
521 })
522 .with_precision_and_scale(precision, scale)
523 .map(|a| Arc::new(a) as ArrayRef)
524 } else {
525 array
526 .try_unary::<_, Decimal256Type, _>(|v| {
527 i256::from_f64((v.as_() * mul).round())
528 .ok_or_else(|| {
529 ArrowError::CastError(format!(
530 "Cannot cast to {}({}, {}). Overflowing on {:?}",
531 Decimal256Type::PREFIX,
532 precision,
533 scale,
534 v
535 ))
536 })
537 .and_then(|v| {
538 Decimal256Type::validate_decimal_precision(v, precision).map(|_| v)
539 })
540 })?
541 .with_precision_and_scale(precision, scale)
542 .map(|a| Arc::new(a) as ArrayRef)
543 }
544}
545
546pub(crate) fn cast_decimal_to_integer<D, T>(
547 array: &dyn Array,
548 base: D::Native,
549 scale: i8,
550 cast_options: &CastOptions,
551) -> Result<ArrayRef, ArrowError>
552where
553 T: ArrowPrimitiveType,
554 <T as ArrowPrimitiveType>::Native: NumCast,
555 D: DecimalType + ArrowPrimitiveType,
556 <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
557{
558 let array = array.as_primitive::<D>();
559
560 let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
561 ArrowError::CastError(format!(
562 "Cannot cast to {:?}. The scale {} causes overflow.",
563 D::PREFIX,
564 scale,
565 ))
566 })?;
567
568 let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
569
570 if cast_options.safe {
571 for i in 0..array.len() {
572 if array.is_null(i) {
573 value_builder.append_null();
574 } else {
575 let v = array
576 .value(i)
577 .div_checked(div)
578 .ok()
579 .and_then(<T::Native as NumCast>::from::<D::Native>);
580
581 value_builder.append_option(v);
582 }
583 }
584 } else {
585 for i in 0..array.len() {
586 if array.is_null(i) {
587 value_builder.append_null();
588 } else {
589 let v = array.value(i).div_checked(div)?;
590
591 let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
592 ArrowError::CastError(format!(
593 "value of {:?} is out of range {}",
594 v,
595 T::DATA_TYPE
596 ))
597 })?;
598
599 value_builder.append_value(value);
600 }
601 }
602 }
603 Ok(Arc::new(value_builder.finish()))
604}
605
606pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
608 array: &dyn Array,
609 op: F,
610) -> Result<ArrayRef, ArrowError>
611where
612 F: Fn(D::Native) -> T::Native,
613{
614 let array = array.as_primitive::<D>();
615 let array = array.unary::<_, T>(op);
616 Ok(Arc::new(array))
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `parse_string_to_decimal_native` for integers, padded fractions
    /// and fractions that must be rounded down to the requested scale.
    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input, scale, expected native value)
        let cases: [(&str, usize, i128); 8] = [
            ("0", 0, 0_i128),
            ("0", 5, 0_i128),
            ("123", 0, 123_i128),
            ("123", 5, 12300000_i128),
            ("123.45", 0, 123_i128),
            ("123.45", 5, 12345000_i128),
            ("123.4567891", 0, 123_i128),
            ("123.4567891", 5, 12345679_i128),
        ];

        for &(input, scale, expected) in &cases {
            assert_eq!(
                parse_string_to_decimal_native::<Decimal128Type>(input, scale)?,
                expected
            );
        }
        Ok(())
    }
}