arrow_cast/cast/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::cast::*;
19
20/// A utility trait that provides checked conversions between
21/// decimal types inspired by [`NumCast`]
22pub(crate) trait DecimalCast: Sized {
23    fn to_i128(self) -> Option<i128>;
24
25    fn to_i256(self) -> Option<i256>;
26
27    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
28}
29
30impl DecimalCast for i128 {
31    fn to_i128(self) -> Option<i128> {
32        Some(self)
33    }
34
35    fn to_i256(self) -> Option<i256> {
36        Some(i256::from_i128(self))
37    }
38
39    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
40        n.to_i128()
41    }
42}
43
44impl DecimalCast for i256 {
45    fn to_i128(self) -> Option<i128> {
46        self.to_i128()
47    }
48
49    fn to_i256(self) -> Option<i256> {
50        Some(self)
51    }
52
53    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
54        n.to_i256()
55    }
56}
57
58pub(crate) fn cast_decimal_to_decimal_error<I, O>(
59    output_precision: u8,
60    output_scale: i8,
61) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
62where
63    I: DecimalType,
64    O: DecimalType,
65    I::Native: DecimalCast + ArrowNativeTypeOp,
66    O::Native: DecimalCast + ArrowNativeTypeOp,
67{
68    move |x: I::Native| {
69        ArrowError::CastError(format!(
70            "Cannot cast to {}({}, {}). Overflowing on {:?}",
71            O::PREFIX,
72            output_precision,
73            output_scale,
74            x
75        ))
76    }
77}
78
79pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
80    array: &PrimitiveArray<I>,
81    input_scale: i8,
82    output_precision: u8,
83    output_scale: i8,
84    cast_options: &CastOptions,
85) -> Result<PrimitiveArray<O>, ArrowError>
86where
87    I: DecimalType,
88    O: DecimalType,
89    I::Native: DecimalCast + ArrowNativeTypeOp,
90    O::Native: DecimalCast + ArrowNativeTypeOp,
91{
92    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
93    let div = I::Native::from_decimal(10_i128)
94        .unwrap()
95        .pow_checked((input_scale - output_scale) as u32)?;
96
97    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
98    let half_neg = half.neg_wrapping();
99
100    let f = |x: I::Native| {
101        // div is >= 10 and so this cannot overflow
102        let d = x.div_wrapping(div);
103        let r = x.mod_wrapping(div);
104
105        // Round result
106        let adjusted = match x >= I::Native::ZERO {
107            true if r >= half => d.add_wrapping(I::Native::ONE),
108            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
109            _ => d,
110        };
111        O::Native::from_decimal(adjusted)
112    };
113
114    Ok(match cast_options.safe {
115        true => array.unary_opt(f),
116        false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
117    })
118}
119
120pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
121    array: &PrimitiveArray<I>,
122    input_scale: i8,
123    output_precision: u8,
124    output_scale: i8,
125    cast_options: &CastOptions,
126) -> Result<PrimitiveArray<O>, ArrowError>
127where
128    I: DecimalType,
129    O: DecimalType,
130    I::Native: DecimalCast + ArrowNativeTypeOp,
131    O::Native: DecimalCast + ArrowNativeTypeOp,
132{
133    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
134    let mul = O::Native::from_decimal(10_i128)
135        .unwrap()
136        .pow_checked((output_scale - input_scale) as u32)?;
137
138    let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
139
140    Ok(match cast_options.safe {
141        true => array.unary_opt(f),
142        false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
143    })
144}
145
146// Only support one type of decimal cast operations
147pub(crate) fn cast_decimal_to_decimal_same_type<T>(
148    array: &PrimitiveArray<T>,
149    input_scale: i8,
150    output_precision: u8,
151    output_scale: i8,
152    cast_options: &CastOptions,
153) -> Result<ArrayRef, ArrowError>
154where
155    T: DecimalType,
156    T::Native: DecimalCast + ArrowNativeTypeOp,
157{
158    let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
159        Ordering::Equal => {
160            // the scale doesn't change, the native value don't need to be changed
161            array.clone()
162        }
163        Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
164            array,
165            input_scale,
166            output_precision,
167            output_scale,
168            cast_options,
169        )?,
170        Ordering::Less => {
171            // input_scale < output_scale
172            convert_to_bigger_or_equal_scale_decimal::<T, T>(
173                array,
174                input_scale,
175                output_precision,
176                output_scale,
177                cast_options,
178            )?
179        }
180    };
181
182    Ok(Arc::new(array.with_precision_and_scale(
183        output_precision,
184        output_scale,
185    )?))
186}
187
188// Support two different types of decimal cast operations
189pub(crate) fn cast_decimal_to_decimal<I, O>(
190    array: &PrimitiveArray<I>,
191    input_scale: i8,
192    output_precision: u8,
193    output_scale: i8,
194    cast_options: &CastOptions,
195) -> Result<ArrayRef, ArrowError>
196where
197    I: DecimalType,
198    O: DecimalType,
199    I::Native: DecimalCast + ArrowNativeTypeOp,
200    O::Native: DecimalCast + ArrowNativeTypeOp,
201{
202    let array: PrimitiveArray<O> = if input_scale > output_scale {
203        convert_to_smaller_scale_decimal::<I, O>(
204            array,
205            input_scale,
206            output_precision,
207            output_scale,
208            cast_options,
209        )?
210    } else {
211        convert_to_bigger_or_equal_scale_decimal::<I, O>(
212            array,
213            input_scale,
214            output_precision,
215            output_scale,
216            cast_options,
217        )?
218    };
219
220    Ok(Arc::new(array.with_precision_and_scale(
221        output_precision,
222        output_scale,
223    )?))
224}
225
226/// Parses given string to specified decimal native (i128/i256) based on given
227/// scale. Returns an `Err` if it cannot parse given string.
228pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
229    value_str: &str,
230    scale: usize,
231) -> Result<T::Native, ArrowError>
232where
233    T::Native: DecimalCast + ArrowNativeTypeOp,
234{
235    let value_str = value_str.trim();
236    let parts: Vec<&str> = value_str.split('.').collect();
237    if parts.len() > 2 {
238        return Err(ArrowError::InvalidArgumentError(format!(
239            "Invalid decimal format: {value_str:?}"
240        )));
241    }
242
243    let (negative, first_part) = if parts[0].is_empty() {
244        (false, parts[0])
245    } else {
246        match parts[0].as_bytes()[0] {
247            b'-' => (true, &parts[0][1..]),
248            b'+' => (false, &parts[0][1..]),
249            _ => (false, parts[0]),
250        }
251    };
252
253    let integers = first_part;
254    let decimals = if parts.len() == 2 { parts[1] } else { "" };
255
256    if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
257        return Err(ArrowError::InvalidArgumentError(format!(
258            "Invalid decimal format: {value_str:?}"
259        )));
260    }
261
262    if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
263        return Err(ArrowError::InvalidArgumentError(format!(
264            "Invalid decimal format: {value_str:?}"
265        )));
266    }
267
268    // Adjust decimal based on scale
269    let mut number_decimals = if decimals.len() > scale {
270        let decimal_number = i256::from_string(decimals).ok_or_else(|| {
271            ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
272        })?;
273
274        let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
275
276        let half = div.div_wrapping(i256::from_i128(2));
277        let half_neg = half.neg_wrapping();
278
279        let d = decimal_number.div_wrapping(div);
280        let r = decimal_number.mod_wrapping(div);
281
282        // Round result
283        let adjusted = match decimal_number >= i256::ZERO {
284            true if r >= half => d.add_wrapping(i256::ONE),
285            false if r <= half_neg => d.sub_wrapping(i256::ONE),
286            _ => d,
287        };
288
289        let integers = if !integers.is_empty() {
290            i256::from_string(integers)
291                .ok_or_else(|| {
292                    ArrowError::InvalidArgumentError(format!(
293                        "Cannot parse decimal format: {value_str}"
294                    ))
295                })
296                .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
297        } else {
298            i256::ZERO
299        };
300
301        format!("{}", integers.add_wrapping(adjusted))
302    } else {
303        let padding = if scale > decimals.len() { scale } else { 0 };
304
305        let decimals = format!("{decimals:0<padding$}");
306        format!("{integers}{decimals}")
307    };
308
309    if negative {
310        number_decimals.insert(0, '-');
311    }
312
313    let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
314        ArrowError::InvalidArgumentError(format!(
315            "Cannot convert {} to {}: Overflow",
316            value_str,
317            T::PREFIX
318        ))
319    })?;
320
321    T::Native::from_decimal(value).ok_or_else(|| {
322        ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
323    })
324}
325
326pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
327    from: &'a S,
328    precision: u8,
329    scale: i8,
330    cast_options: &CastOptions,
331) -> Result<PrimitiveArray<T>, ArrowError>
332where
333    T: DecimalType,
334    T::Native: DecimalCast + ArrowNativeTypeOp,
335    &'a S: StringArrayType<'a>,
336{
337    if cast_options.safe {
338        let iter = from.iter().map(|v| {
339            v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
340                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
341        });
342        // Benefit:
343        //     20% performance improvement
344        // Soundness:
345        //     The iterator is trustedLen because it comes from an `StringArray`.
346        Ok(unsafe {
347            PrimitiveArray::<T>::from_trusted_len_iter(iter)
348                .with_precision_and_scale(precision, scale)?
349        })
350    } else {
351        let vec = from
352            .iter()
353            .map(|v| {
354                v.map(|v| {
355                    parse_string_to_decimal_native::<T>(v, scale as usize)
356                        .map_err(|_| {
357                            ArrowError::CastError(format!(
358                                "Cannot cast string '{}' to value of {:?} type",
359                                v,
360                                T::DATA_TYPE,
361                            ))
362                        })
363                        .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
364                })
365                .transpose()
366            })
367            .collect::<Result<Vec<_>, _>>()?;
368        // Benefit:
369        //     20% performance improvement
370        // Soundness:
371        //     The iterator is trustedLen because it comes from an `StringArray`.
372        Ok(unsafe {
373            PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
374                .with_precision_and_scale(precision, scale)?
375        })
376    }
377}
378
379pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
380    from: &GenericStringArray<Offset>,
381    precision: u8,
382    scale: i8,
383    cast_options: &CastOptions,
384) -> Result<PrimitiveArray<T>, ArrowError>
385where
386    T: DecimalType,
387    T::Native: DecimalCast + ArrowNativeTypeOp,
388{
389    generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
390        from,
391        precision,
392        scale,
393        cast_options,
394    )
395}
396
397pub(crate) fn string_view_to_decimal_cast<T>(
398    from: &StringViewArray,
399    precision: u8,
400    scale: i8,
401    cast_options: &CastOptions,
402) -> Result<PrimitiveArray<T>, ArrowError>
403where
404    T: DecimalType,
405    T::Native: DecimalCast + ArrowNativeTypeOp,
406{
407    generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
408}
409
410/// Cast Utf8 to decimal
411pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
412    from: &dyn Array,
413    precision: u8,
414    scale: i8,
415    cast_options: &CastOptions,
416) -> Result<ArrayRef, ArrowError>
417where
418    T: DecimalType,
419    T::Native: DecimalCast + ArrowNativeTypeOp,
420{
421    if scale < 0 {
422        return Err(ArrowError::InvalidArgumentError(format!(
423            "Cannot cast string to decimal with negative scale {scale}"
424        )));
425    }
426
427    if scale > T::MAX_SCALE {
428        return Err(ArrowError::InvalidArgumentError(format!(
429            "Cannot cast string to decimal greater than maximum scale {}",
430            T::MAX_SCALE
431        )));
432    }
433
434    let result = match from.data_type() {
435        DataType::Utf8View => string_view_to_decimal_cast::<T>(
436            from.as_any().downcast_ref::<StringViewArray>().unwrap(),
437            precision,
438            scale,
439            cast_options,
440        )?,
441        DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
442            from.as_any()
443                .downcast_ref::<GenericStringArray<Offset>>()
444                .unwrap(),
445            precision,
446            scale,
447            cast_options,
448        )?,
449        other => {
450            return Err(ArrowError::ComputeError(format!(
451                "Cannot cast {:?} to decimal",
452                other
453            )))
454        }
455    };
456
457    Ok(Arc::new(result))
458}
459
460pub(crate) fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
461    array: &PrimitiveArray<T>,
462    precision: u8,
463    scale: i8,
464    cast_options: &CastOptions,
465) -> Result<ArrayRef, ArrowError>
466where
467    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
468{
469    let mul = 10_f64.powi(scale as i32);
470
471    if cast_options.safe {
472        array
473            .unary_opt::<_, Decimal128Type>(|v| {
474                (mul * v.as_())
475                    .round()
476                    .to_i128()
477                    .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision))
478            })
479            .with_precision_and_scale(precision, scale)
480            .map(|a| Arc::new(a) as ArrayRef)
481    } else {
482        array
483            .try_unary::<_, Decimal128Type, _>(|v| {
484                (mul * v.as_())
485                    .round()
486                    .to_i128()
487                    .ok_or_else(|| {
488                        ArrowError::CastError(format!(
489                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
490                            Decimal128Type::PREFIX,
491                            precision,
492                            scale,
493                            v
494                        ))
495                    })
496                    .and_then(|v| {
497                        Decimal128Type::validate_decimal_precision(v, precision).map(|_| v)
498                    })
499            })?
500            .with_precision_and_scale(precision, scale)
501            .map(|a| Arc::new(a) as ArrayRef)
502    }
503}
504
505pub(crate) fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
506    array: &PrimitiveArray<T>,
507    precision: u8,
508    scale: i8,
509    cast_options: &CastOptions,
510) -> Result<ArrayRef, ArrowError>
511where
512    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
513{
514    let mul = 10_f64.powi(scale as i32);
515
516    if cast_options.safe {
517        array
518            .unary_opt::<_, Decimal256Type>(|v| {
519                i256::from_f64((v.as_() * mul).round())
520                    .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision))
521            })
522            .with_precision_and_scale(precision, scale)
523            .map(|a| Arc::new(a) as ArrayRef)
524    } else {
525        array
526            .try_unary::<_, Decimal256Type, _>(|v| {
527                i256::from_f64((v.as_() * mul).round())
528                    .ok_or_else(|| {
529                        ArrowError::CastError(format!(
530                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
531                            Decimal256Type::PREFIX,
532                            precision,
533                            scale,
534                            v
535                        ))
536                    })
537                    .and_then(|v| {
538                        Decimal256Type::validate_decimal_precision(v, precision).map(|_| v)
539                    })
540            })?
541            .with_precision_and_scale(precision, scale)
542            .map(|a| Arc::new(a) as ArrayRef)
543    }
544}
545
546pub(crate) fn cast_decimal_to_integer<D, T>(
547    array: &dyn Array,
548    base: D::Native,
549    scale: i8,
550    cast_options: &CastOptions,
551) -> Result<ArrayRef, ArrowError>
552where
553    T: ArrowPrimitiveType,
554    <T as ArrowPrimitiveType>::Native: NumCast,
555    D: DecimalType + ArrowPrimitiveType,
556    <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
557{
558    let array = array.as_primitive::<D>();
559
560    let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
561        ArrowError::CastError(format!(
562            "Cannot cast to {:?}. The scale {} causes overflow.",
563            D::PREFIX,
564            scale,
565        ))
566    })?;
567
568    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
569
570    if cast_options.safe {
571        for i in 0..array.len() {
572            if array.is_null(i) {
573                value_builder.append_null();
574            } else {
575                let v = array
576                    .value(i)
577                    .div_checked(div)
578                    .ok()
579                    .and_then(<T::Native as NumCast>::from::<D::Native>);
580
581                value_builder.append_option(v);
582            }
583        }
584    } else {
585        for i in 0..array.len() {
586            if array.is_null(i) {
587                value_builder.append_null();
588            } else {
589                let v = array.value(i).div_checked(div)?;
590
591                let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
592                    ArrowError::CastError(format!(
593                        "value of {:?} is out of range {}",
594                        v,
595                        T::DATA_TYPE
596                    ))
597                })?;
598
599                value_builder.append_value(value);
600            }
601        }
602    }
603    Ok(Arc::new(value_builder.finish()))
604}
605
606// Cast the decimal array to floating-point array
607pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
608    array: &dyn Array,
609    op: F,
610) -> Result<ArrayRef, ArrowError>
611where
612    F: Fn(D::Native) -> T::Native,
613{
614    let array = array.as_primitive::<D>();
615    let array = array.unary::<_, T>(op);
616    Ok(Arc::new(array))
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 0)?,
            0_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 5)?,
            0_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 5)?,
            12300000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 5)?,
            12345000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 5)?,
            12345679_i128
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_signs_and_rounding() -> Result<(), ArrowError> {
        // Rounding is half away from zero: it is applied to the magnitude and
        // the sign is attached afterwards.
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-123.455", 2)?,
            -12346_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("+0.5", 0)?,
            1_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-0.4", 0)?,
            0_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>(".5", 0)?,
            1_i128
        );
        // Rounding may carry into the integer part.
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0.999", 2)?,
            100_i128
        );
        // Surrounding whitespace is trimmed.
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("  42  ", 1)?,
            420_i128
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_invalid() {
        for s in ["1.2.3", "abc", "1.-2", "--1", "-+1"] {
            assert!(
                parse_string_to_decimal_native::<Decimal128Type>(s, 2).is_err(),
                "{s} should be rejected"
            );
        }
    }

    #[test]
    fn test_parse_string_to_decimal_native_native_overflow() {
        // 10^40 does not fit in i128 but does fit in i256.
        let too_big = format!("1{}", "0".repeat(40));
        assert!(parse_string_to_decimal_native::<Decimal128Type>(&too_big, 0).is_err());

        let expected = i256::from_string(&too_big).unwrap();
        assert_eq!(
            parse_string_to_decimal_native::<Decimal256Type>(&too_big, 0).unwrap(),
            expected
        );
    }
}