Skip to content

Commit

Permalink
Use checked division kernel (#6792)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jun 29, 2023
1 parent 06e22a5 commit 283b8a1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1670,9 +1670,12 @@ mod tests {
/// Checks that simplifying the literal expression `0 / 0` reports a
/// divide-by-zero error.
///
/// NOTE(review): this block comes from a diff rendered without +/- markers,
/// so lines from before and after the commit appear interleaved — confirm
/// against the raw diff before treating it as compilable code.
fn test_simplify_divide_zero_by_zero() {
// 0 / 0 -> DivideByZero error (the earlier expectation was a null literal)
let expr = lit(0) / lit(0);
// NOTE(review): presumably the pre-commit expectation (a null Int32 literal)
// that this commit removes — confirm.
let expected = lit(ScalarValue::Int32(None));
// `try_simplify` must fail here; keep the error so its variant can be checked.
let err = try_simplify(expr).unwrap_err();

// NOTE(review): presumably the pre-commit assertion that this commit
// removes — confirm.
assert_eq!(simplify(expr), expected);
// The error must be Arrow's DivideByZero wrapped in DataFusionError; on a
// mismatch the assertion message prints the error that was actually returned.
assert!(
matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)),
"{err}"
);
}

#[test]
Expand Down
44 changes: 30 additions & 14 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{any::Any, sync::Arc};

use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
add_dyn, add_scalar_dyn as add_dyn_scalar, divide_dyn_opt,
add_dyn, add_scalar_dyn as add_dyn_scalar, divide_dyn_checked,
divide_scalar_dyn as divide_dyn_scalar, modulus_dyn,
modulus_scalar_dyn as modulus_dyn_scalar, multiply_dyn,
multiply_scalar_dyn as multiply_dyn_scalar, subtract_dyn,
Expand Down Expand Up @@ -63,7 +63,7 @@ use kernels::{
};
use kernels_arrow::{
add_decimal_dyn_scalar, add_dyn_decimal, add_dyn_temporal, divide_decimal_dyn_scalar,
divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_binary,
divide_dyn_checked_decimal, is_distinct_from, is_distinct_from_binary,
is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_f32,
is_distinct_from_f64, is_distinct_from_null, is_distinct_from_utf8,
is_not_distinct_from, is_not_distinct_from_binary, is_not_distinct_from_bool,
Expand Down Expand Up @@ -1223,7 +1223,12 @@ impl BinaryExpr {
binary_primitive_array_op_dyn!(left, right, multiply_dyn, result_type)
}
Divide => {
binary_primitive_array_op_dyn!(left, right, divide_dyn_opt, result_type)
binary_primitive_array_op_dyn!(
left,
right,
divide_dyn_checked,
result_type
)
}
Modulo => {
binary_primitive_array_op_dyn!(left, right, modulus_dyn, result_type)
Expand Down Expand Up @@ -1342,6 +1347,7 @@ mod tests {
use arrow::datatypes::{
ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
};
use arrow_schema::ArrowError;
use datafusion_common::{ColumnStatistics, Result, Statistics};
use datafusion_expr::type_coercion::binary::get_input_types;

Expand Down Expand Up @@ -4336,27 +4342,31 @@ mod tests {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));
let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048, 100]));
let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32, 0]));
let a = Arc::new(Int32Array::from(vec![100]));
let b = Arc::new(Int32Array::from(vec![0]));

apply_arithmetic::<Int32Type>(
let err = apply_arithmetic::<Int32Type>(
schema,
vec![a, b],
Operator::Divide,
Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64), None]),
)?;
Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
)
.unwrap_err();

assert!(
matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)),
"{err}"
);

// decimal
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Decimal128(25, 3), true),
Field::new("b", DataType::Decimal128(25, 3), true),
]));
let left_decimal_array =
Arc::new(create_decimal_array(&[Some(1234567), Some(1234567)], 25, 3));
let right_decimal_array =
Arc::new(create_decimal_array(&[Some(10), Some(0)], 25, 3));
let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));

apply_arithmetic::<Decimal128Type>(
let err = apply_arithmetic::<Decimal128Type>(
schema,
vec![left_decimal_array, right_decimal_array],
Operator::Divide,
Expand All @@ -4365,7 +4375,13 @@ mod tests {
38,
29,
),
)?;
)
.unwrap_err();

assert!(
matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)),
"{err}"
);

Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! destined for arrow-rs but are in datafusion until they are ported.

use arrow::compute::{
add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn,
add_dyn, add_scalar_dyn, divide_dyn_checked, divide_scalar_dyn, modulus_dyn,
modulus_scalar_dyn, multiply_dyn, multiply_fixed_point, multiply_scalar_dyn,
subtract_dyn, subtract_scalar_dyn, try_unary,
};
Expand Down Expand Up @@ -847,7 +847,7 @@ pub(crate) fn multiply_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}

pub(crate) fn divide_dyn_opt_decimal(
pub(crate) fn divide_dyn_checked_decimal(
left: &dyn Array,
right: &dyn Array,
result_type: &DataType,
Expand All @@ -860,7 +860,7 @@ pub(crate) fn divide_dyn_opt_decimal(
// Restore to original precision and scale (metadata only)
let (org_precision, org_scale) = get_precision_scale(right.data_type())?;
let array = decimal_array_with_precision_scale(array, org_precision, org_scale)?;
let array = divide_dyn_opt(&array, right)?;
let array = divide_dyn_checked(&array, right)?;
decimal_array_with_precision_scale(array, precision, scale)
}

Expand Down Expand Up @@ -2352,7 +2352,7 @@ mod tests {
25,
3,
);
let result = divide_dyn_opt_decimal(
let result = divide_dyn_checked_decimal(
&left_decimal_array,
&right_decimal_array,
&result_type,
Expand Down

0 comments on commit 283b8a1

Please sign in to comment.