Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1600,10 +1600,16 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Validates values in this array can be properly interpreted
/// with the specified precision.
pub fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> {
if precision < self.scale() as u8 {
return Err(ArrowError::InvalidArgumentError(format!(
"Decimal precision {precision} is less than scale {}",
self.scale()
)));
}
(0..self.len()).try_for_each(|idx| {
if self.is_valid(idx) {
let decimal = unsafe { self.value_unchecked(idx) };
T::validate_decimal_precision(decimal, precision)
T::validate_decimal_precision(decimal, precision, self.scale())
} else {
Ok(())
}
Expand Down Expand Up @@ -2436,7 +2442,7 @@ mod tests {
let result = arr.validate_decimal_precision(5);
let error = result.unwrap_err();
assert_eq!(
"Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. Max is 99999",
"Invalid argument error: 123.456 is too large to store in a Decimal128 of precision 5. Max is 99.999",
error.to_string()
);

Expand All @@ -2455,7 +2461,7 @@ mod tests {
let result = arr.validate_decimal_precision(2);
let error = result.unwrap_err();
assert_eq!(
"Invalid argument error: 100 is too large to store in a Decimal128 of precision 2. Max is 99",
"Invalid argument error: 10.0 is too large to store in a Decimal128 of precision 2. Max is 9.9",
error.to_string()
);
}
Expand Down Expand Up @@ -2541,7 +2547,7 @@ mod tests {

#[test]
#[should_panic(
expected = "-123223423432432 is too small to store in a Decimal128 of precision 5. Min is -99999"
expected = "-1232234234324.32 is too small to store in a Decimal128 of precision 5. Min is -999.99"
)]
fn test_decimal_array_with_precision_and_scale_out_of_range() {
let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432])
Expand Down
47 changes: 14 additions & 33 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::timezone::Tz;
use crate::{ArrowNativeTypeOp, OffsetSizeTrait};
use arrow_buffer::{i256, Buffer, OffsetBuffer};
use arrow_data::decimal::{
is_validate_decimal256_precision, is_validate_decimal32_precision,
format_decimal_str, is_validate_decimal256_precision, is_validate_decimal32_precision,
is_validate_decimal64_precision, is_validate_decimal_precision, validate_decimal256_precision,
validate_decimal32_precision, validate_decimal64_precision, validate_decimal_precision,
};
Expand Down Expand Up @@ -1335,7 +1335,11 @@ pub trait DecimalType:
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>;
fn validate_decimal_precision(
value: Self::Native,
precision: u8,
scale: i8,
) -> Result<(), ArrowError>;

/// Determines whether `value` contains no more than `precision` decimal digits
fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool;
Expand Down Expand Up @@ -1398,8 +1402,8 @@ impl DecimalType for Decimal32Type {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i32, precision: u8) -> Result<(), ArrowError> {
validate_decimal32_precision(num, precision)
fn validate_decimal_precision(num: i32, precision: u8, scale: i8) -> Result<(), ArrowError> {
validate_decimal32_precision(num, precision, scale)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
Expand Down Expand Up @@ -1432,8 +1436,8 @@ impl DecimalType for Decimal64Type {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i64, precision: u8) -> Result<(), ArrowError> {
validate_decimal64_precision(num, precision)
fn validate_decimal_precision(num: i64, precision: u8, scale: i8) -> Result<(), ArrowError> {
validate_decimal64_precision(num, precision, scale)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
Expand Down Expand Up @@ -1466,8 +1470,8 @@ impl DecimalType for Decimal128Type {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
validate_decimal_precision(num, precision)
fn validate_decimal_precision(num: i128, precision: u8, scale: i8) -> Result<(), ArrowError> {
validate_decimal_precision(num, precision, scale)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
Expand Down Expand Up @@ -1500,8 +1504,8 @@ impl DecimalType for Decimal256Type {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
validate_decimal256_precision(num, precision)
fn validate_decimal_precision(num: i256, precision: u8, scale: i8) -> Result<(), ArrowError> {
validate_decimal256_precision(num, precision, scale)
}

fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
Expand All @@ -1517,29 +1521,6 @@ impl ArrowPrimitiveType for Decimal256Type {

impl primitive::PrimitiveTypeSealed for Decimal256Type {}

fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
Some(stripped) => ("-", stripped),
None => ("", value_str),
};
let bound = precision.min(rest.len()) + sign.len();
let value_str = &value_str[0..bound];

if scale == 0 {
value_str.to_string()
} else if scale < 0 {
let padding = value_str.len() + scale.unsigned_abs() as usize;
format!("{value_str:0<padding$}")
} else if rest.len() > scale as usize {
// Decimal separator is in the middle of the string
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
format!("{whole}.{decimal}")
} else {
// String has to be padded
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
}
}

/// Crate private types for Byte Arrays
///
/// Not intended to be used outside this crate
Expand Down
14 changes: 8 additions & 6 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,9 @@ where
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
})
}
Expand Down Expand Up @@ -264,8 +265,9 @@ where
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
})
}
Expand Down Expand Up @@ -491,7 +493,7 @@ where
T::DATA_TYPE,
))
})
.and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
.and_then(|v| T::validate_decimal_precision(v, precision, scale).map(|_| v))
})
.transpose()
})
Expand Down Expand Up @@ -621,7 +623,7 @@ where
v
))
})
.and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v))
.and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v))
})?
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
Expand Down
Loading
Loading