diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 2867f61da7..02b544e2d5 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -134,6 +134,7 @@ jobs:
           org.apache.comet.CometCastSuite
           org.apache.comet.CometExpressionSuite
           org.apache.comet.CometExpressionCoverageSuite
+          org.apache.comet.CometMathExpressionSuite
           org.apache.comet.CometNativeSuite
           org.apache.comet.CometSparkSessionExtensionsSuite
           org.apache.comet.CometStringExpressionSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 0fd1cb6066..3a1b82d044 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -99,6 +99,7 @@ jobs:
           org.apache.comet.CometCastSuite
           org.apache.comet.CometExpressionSuite
           org.apache.comet.CometExpressionCoverageSuite
+          org.apache.comet.CometMathExpressionSuite
           org.apache.comet.CometNativeSuite
           org.apache.comet.CometSparkSessionExtensionsSuite
           org.apache.comet.CometStringExpressionSuite
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 537d0d7748..6caaa53b1b 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -164,6 +164,7 @@ These settings can be used to determine which parts of the plan are accelerated
 
 | Config | Description | Default Value |
 |--------|-------------|---------------|
+| `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true |
 | `spark.comet.expression.Acos.enabled` | Enable Comet acceleration for `Acos` | true |
 | `spark.comet.expression.Add.enabled` | Enable Comet acceleration for `Add` | true |
 | `spark.comet.expression.Alias.enabled` | Enable Comet acceleration for `Alias` | true |
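The new `spark.comet.expression.Abs.enabled` entry follows the existing per-expression pattern: each expression can be toggled independently at runtime. A minimal sketch of the fallback path (assuming a `SparkSession` named `spark` with the Comet extension loaded; the table and column names are illustrative):

```scala
// Hypothetical session: disabling the per-expression flag makes Comet fall
// back to Spark's JVM implementation for Abs only, leaving every other
// accelerated expression untouched.
spark.conf.set("spark.comet.expression.Abs.enabled", "false")
spark.sql("SELECT abs(c1) FROM tbl").collect() // abs now runs on the JVM
```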
diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md
index 3ccead03a1..809e69d2f8 100644
--- a/docs/source/user-guide/latest/expressions.md
+++ b/docs/source/user-guide/latest/expressions.md
@@ -118,6 +118,7 @@ incompatible expressions.
 
 | Expression | SQL | Spark-Compatible? | Compatibility Notes |
 |----------------|-----------|-------------------|-----------------------------------|
+| Abs | `abs` | Yes | |
 | Acos | `acos` | Yes | |
 | Add | `+` | Yes | |
 | Asin | `asin` | Yes | |
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index a37f928e97..a33df705b3 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -674,19 +674,6 @@ impl PhysicalPlanner {
                 let op = DataFusionOperator::BitwiseShiftLeft;
                 Ok(Arc::new(BinaryExpr::new(left, op, right)))
             }
-            // https://github.com/apache/datafusion-comet/issues/666
-            // ExprStruct::Abs(expr) => {
-            //     let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
-            //     let return_type = child.data_type(&input_schema)?;
-            //     let args = vec![child];
-            //     let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
-            //     let comet_abs = Arc::new(ScalarUDF::new_from_impl(Abs::new(
-            //         eval_mode,
-            //         return_type.to_string(),
-            //     )?));
-            //     let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type);
-            //     Ok(Arc::new(expr))
-            // }
             ExprStruct::CaseWhen(case_when) => {
                 let when_then_pairs = case_when
                     .when
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 5853bc613c..c9037dcd69 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -70,7 +70,6 @@ message Expr {
     IfExpr if = 44;
     NormalizeNaNAndZero normalize_nan_and_zero = 45;
     TruncTimestamp truncTimestamp = 47;
-    Abs abs = 49;
     Subquery subquery = 50;
     UnboundReference unbound = 51;
     BloomFilterMightContain bloom_filter_might_contain = 52;
@@ -351,11 +350,6 @@ message TruncTimestamp {
   string timezone = 3;
 }
 
-message Abs {
-  Expr child = 1;
-  EvalMode eval_mode = 2;
-}
-
 message Subquery {
   int64 id = 1;
   DataType datatype = 2;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index fc0c096b15..021bb1c78f 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::hash_funcs::*;
+use crate::math_funcs::abs::abs;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
@@ -180,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_modulo);
             make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
         }
+        "abs" => {
+            let func = Arc::new(abs);
+            make_comet_scalar_udf!("abs", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs
new file mode 100644
index 0000000000..5a16398ec4
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/abs.rs
@@ -0,0 +1,890 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::arithmetic_overflow_error;
+use arrow::array::*;
+use arrow::datatypes::*;
+use arrow::error::ArrowError;
+use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use std::sync::Arc;
+
+macro_rules! legacy_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
+                Ok(res)
+            }
+            _ => Err(DataFusionError::Internal(
+                "Invalid data type for abs".to_string(),
+            )),
+        }
+    }};
+}
+
+macro_rules! ansi_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                match arrow::compute::kernels::arity::try_unary(array, |x| {
+                    if x == $NATIVE::MIN {
+                        Err(ArrowError::ArithmeticOverflow($FROM_TYPE.to_string()))
+                    } else {
+                        Ok(x.$FUNC())
+                    }
+                }) {
+                    Ok(res) => Ok(ColumnarValue::Array(Arc::<PrimitiveArray<$RESULT>>::new(
+                        res,
+                    ))),
+                    Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()),
+                }
+            }
+            _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+        }
+    }};
+}
+
+/// This function mimics Spark SQL's [Abs](https://github.com/apache/spark/blob/v4.0.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala#L148).
+/// Spark's [ANSI-compliant](https://spark.apache.org/docs/latest/sql-ref-ansi-compliance.html#arithmetic-operations)
+/// dialect mode throws `org.apache.spark.SparkArithmeticException` when abs causes overflow.
+pub fn abs(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.is_empty() || args.len() > 2 {
+        return exec_err!("abs takes 1 or 2 arguments, but got: {}", args.len());
+    }
+
+    let fail_on_error = if args.len() == 2 {
+        match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
+            _ => {
+                return exec_err!(
+                    "The second argument must be a boolean scalar, but got: {:?}",
+                    args[1]
+                );
+            }
+        }
+    } else {
+        false
+    };
+
+    match &args[0] {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Null
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => Ok(args[0].clone()),
+            DataType::Int8 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int8Array, Int8Type, i8, "Int8")
+                }
+            }
+            DataType::Int16 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int16Array, Int16Type, i16, "Int16")
+                }
+            }
+            DataType::Int32 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int32Array, Int32Type, i32, "Int32")
+                }
+            }
+            DataType::Int64 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int64Array, Int64Type, i64, "Int64")
+                }
+            }
+            DataType::Float32 => {
+                let result = legacy_compute_op!(array, abs, Float32Array, Float32Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            DataType::Float64 => {
+                let result = legacy_compute_op!(array, abs, Float64Array, Float64Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            DataType::Decimal128(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal128Array, Decimal128Array)?;
+                    let result = result.with_data_type(DataType::Decimal128(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal128Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i128::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal128".to_string()))
+                                } else {
+                                    Ok(x.abs())
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal128Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal128(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal128").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal256Array, Decimal256Array)?;
+                    let result = result.with_data_type(DataType::Decimal256(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal256Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i256::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal256".to_string()))
+                                } else {
+                                    Ok(x.wrapping_abs()) // i256 doesn't define abs() method
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal256Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal256(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal256").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            dt => exec_err!("Unsupported datatype for ABS: {dt}"),
+        },
+        ColumnarValue::Scalar(sv) => match sv {
+            ScalarValue::Null
+            | ScalarValue::UInt8(_)
+            | ScalarValue::UInt16(_)
+            | ScalarValue::UInt32(_)
+            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
+            ScalarValue::Int8(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int8").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int16(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int16").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int32(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int32").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int64(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int64").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(
+                a.map(|x| x.abs()),
+            ))),
+            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(
+                a.map(|x| x.abs()),
+            ))),
+            ScalarValue::Decimal128(a, precision, scale) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        Some(abs_val),
+                        *precision,
+                        *scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                                Some(*v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal128").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Decimal256(a, precision, scale) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                        Some(abs_val),
+                        *precision,
+                        *scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                                Some(*v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal256").into())
+                        }
+                    }
+                },
+            },
+            dt => exec_err!("Unsupported datatype for ABS: {dt}"),
+        },
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::cast::{
+        as_decimal128_array, as_decimal256_array, as_float32_array, as_float64_array,
+        as_int16_array, as_int32_array, as_int64_array, as_int8_array, as_uint64_array,
+    };
+
+    fn with_fail_on_error<F: Fn(bool) -> Result<()>>(test_fn: F) {
+        for fail_on_error in [true, false] {
+            test_fn(fail_on_error).expect("test case should succeed");
+        }
+    }
+
+    // Unsigned types, return as is
+    #[test]
+    fn test_abs_u8_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => {
+                    assert_eq!(result, u8::MAX);
+                    Ok(())
+                }
+                Err(e) => {
+                    unreachable!("Didn't expect error, but got: {e:?}")
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i8_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => {
+                    assert_eq!(result, i8::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        unreachable!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i16_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => {
+                    assert_eq!(result, i16::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        unreachable!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i32_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => {
+                    assert_eq!(result, i32::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i64_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => {
+                    assert_eq!(result, i64::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal128_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, i128::MIN);
+                    assert_eq!(precision, 18);
+                    assert_eq!(scale, 10);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal256_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, i256::MIN);
+                    assert_eq!(precision, 10);
+                    assert_eq!(scale, 2);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i8_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int8_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i16_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int16_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i32_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int32_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_f32_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Float32Array::from(vec![
+                Some(-1f32),
+                Some(f32::MIN),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::NEG_INFINITY),
+                Some(f32::INFINITY),
+                Some(-0.0),
+                Some(0.0),
+            ]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Float32Array::from(vec![
+                Some(1f32),
+                Some(f32::MAX),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::INFINITY),
+                Some(f32::INFINITY),
+                Some(0.0),
+                Some(0.0),
+            ]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_float32_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_f64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Float64Array::from(vec![Some(-1f64), Some(f64::MIN), Some(f64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected =
+                Float64Array::from(vec![Some(1f64), Some(f64::MAX), Some(f64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_float64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal128_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)?;
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)?;
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_decimal128_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal256_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)?;
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)?;
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_decimal256_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_u64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_uint64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_null_scalars() {
+        // Test that NULL scalars return NULL (no panic) for all signed types
+        with_fail_on_error(|fail_on_error| {
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            // Test Int8
+            let args = ColumnarValue::Scalar(ScalarValue::Int8(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int8(None))) => {}
+                _ => panic!("Expected NULL Int8, got different result"),
+            }
+
+            // Test Int16
+            let args = ColumnarValue::Scalar(ScalarValue::Int16(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int16(None))) => {}
+                _ => panic!("Expected NULL Int16, got different result"),
+            }
+
+            // Test Int32
+            let args = ColumnarValue::Scalar(ScalarValue::Int32(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) => {}
+                _ => panic!("Expected NULL Int32, got different result"),
+            }
+
+            // Test Int64
+            let args = ColumnarValue::Scalar(ScalarValue::Int64(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))) => {}
+                _ => panic!("Expected NULL Int64, got different result"),
+            }
+
+            // Test Decimal128
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2))) => {}
+                _ => panic!("Expected NULL Decimal128, got different result"),
+            }
+
+            // Test Decimal256
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2))) => {}
+                _ => panic!("Expected NULL Decimal256, got different result"),
+            }
+
+            // Test Float32
+            let args = ColumnarValue::Scalar(ScalarValue::Float32(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(None))) => {}
+                _ => panic!("Expected NULL Float32, got different result"),
+            }
+
+            // Test Float64
+            let args = ColumnarValue::Scalar(ScalarValue::Float64(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) => {}
+                _ => panic!("Expected NULL Float64, got different result"),
+            }
+
+            Ok(())
+        });
+    }
+}
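The legacy/ANSI split implemented above comes down to wrapping versus checked negation: in two's complement, `MIN` has no positive counterpart, so legacy mode returns the input unchanged while ANSI mode must raise. The same split can be observed on the JVM, which is what Spark's own `Abs` builds on (a sketch; `Math.absExact` is the JDK 15+ checked analogue):

```scala
// Legacy (wrapping) behavior: abs(Int.MinValue) silently returns Int.MinValue,
// mirroring wrapping_abs in the native code above.
assert(math.abs(Int.MinValue) == Int.MinValue)

// ANSI-style (checked) behavior: the checked variant rejects Int.MinValue,
// mirroring the try_unary overflow branch above.
try Math.absExact(Int.MinValue)
catch { case e: ArithmeticException => println(s"overflow: ${e.getMessage}") }
```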
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index 873b290ebd..7df87eb9f2 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub(crate) mod abs;
 mod ceil;
 pub(crate) mod checked_arithmetic;
 mod div;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 570c07cb09..63e18c145a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -137,7 +137,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Subtract] -> CometSubtract,
     classOf[Tan] -> CometScalarFunction("tan"),
     classOf[UnaryMinus] -> CometUnaryMinus,
-    classOf[Unhex] -> CometUnhex)
+    classOf[Unhex] -> CometUnhex,
+    classOf[Abs] -> CometAbs)
 
   private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[GetMapValue] -> CometMapExtract,
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
index bfcd242d76..68b6e8d11e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,8 +19,8 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
+import org.apache.spark.sql.types.{DecimalType, NumericType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -144,6 +144,36 @@ object CometUnhex extends CometExpressionSerde[Unhex] with MathExprBase {
   }
 }
 
+object CometAbs extends CometExpressionSerde[Abs] with MathExprBase {
+
+  override def getSupportLevel(expr: Abs): SupportLevel = {
+    expr.child.dataType match {
+      case _: NumericType =>
+        Compatible()
+      case _ =>
+        // Spark supports NumericType, DayTimeIntervalType, and YearMonthIntervalType
+        Unsupported(Some("Only integral, floating-point, and decimal types are supported"))
+    }
+  }
+
+  override def convert(
+      expr: Abs,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    val failOnErrorExpr = exprToProtoInternal(Literal(expr.failOnError), inputs, binding)
+
+    val optExpr =
+      scalarFunctionExprToProtoWithReturnType(
+        "abs",
+        expr.dataType,
+        false,
+        childExpr,
+        failOnErrorExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
 sealed trait MathExprBase {
   protected def nullIfNegative(expression: Expression): Expression = {
     val zero = Literal.default(expression.dataType)
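Note for reviewers: `CometAbs.getSupportLevel` above reports `Compatible()` only for `NumericType` children, so the interval types that Spark's `Abs` also accepts are declared `Unsupported` and the expression falls back to Spark rather than risking wrong results. A hypothetical illustration (dataset contents are made up):

```scala
import java.time.Duration

// Numeric child: eligible for the native abs path added in this PR.
Seq(-1, -2).toDF("i").selectExpr("abs(i)").collect()

// DayTimeIntervalType child: Spark supports abs here, but CometAbs reports
// Unsupported, so this is expected to run on the JVM, not in Comet.
Seq(Duration.ofDays(-1)).toDF("d").selectExpr("abs(d)").collect()
```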
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 7b6ed19452..d502749380 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -21,8 +21,6 @@ package org.apache.comet
 
 import java.time.{Duration, Period}
 
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
 
 import org.scalactic.source.Position
@@ -1430,74 +1428,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     testDoubleScalarExpr("expm1")
   }
 
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs") {
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
-        withParquetTable(path.toString, "tbl") {
-          Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col =>
-            checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
-          }
-        }
-      }
-    }
-  }
-
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs Overflow ansi mode") {
-
-    def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withParquetTable(data, "tbl") {
-        checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match {
-          case (Some(sparkExc), Some(cometExc)) =>
-            val cometErrorPattern =
-              """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r
-            assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
-            assert(sparkExc.getMessage.contains("overflow"))
-          case _ => fail("Exception should be thrown")
-        }
-      }
-    }
-
-    def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withParquetTable(data, "tbl") {
-        checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
-      }
-    }
-
-    withSQLConf(
-      SQLConf.ANSI_ENABLED.key -> "true",
-      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-      testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
-      testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue)))
-      testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue)))
-      testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue)))
-      testAbsAnsi(Seq((Float.MaxValue, Float.MinValue)))
-      testAbsAnsi(Seq((Double.MaxValue, Double.MinValue)))
-    }
-  }
-
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs Overflow legacy mode") {
-
-    def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
-        withParquetTable(data, "tbl") {
-          checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
-        }
-      }
-    }
-
-    testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
-    testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue)))
-    testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue)))
-    testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue)))
-    testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue)))
-    testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue)))
-  }
-
   test("ceil and floor") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf(
diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
new file mode 100644
index 0000000000..c95047a0ef
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  test("abs") {
+    val df = createTestData(generateNegativeZero = false)
+    df.createOrReplaceTempView("tbl")
+    for (field <- df.schema.fields) {
+      val col = field.name
+      checkSparkAnswerAndOperator(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")
+    }
+  }
+
+  test("abs - negative zero") {
+    val df = createTestData(generateNegativeZero = true)
+    df.createOrReplaceTempView("tbl")
+    for (field <- df.schema.fields.filter(f =>
+        f.dataType == DataTypes.FloatType || f.dataType == DataTypes.DoubleType)) {
+      val col = field.name
+      checkSparkAnswerAndOperator(
+        s"SELECT $col, abs($col) FROM tbl WHERE CAST($col as string) = '-0.0' ORDER BY $col")
+    }
+  }
+
+  test("abs (ANSI mode)") {
+    val df = createTestData(generateNegativeZero = false)
+    df.createOrReplaceTempView("tbl")
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      for (field <- df.schema.fields) {
+        val col = field.name
+        checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            val cometErrorPattern =
+              """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r
+            assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
+            assert(sparkExc.getMessage.contains("overflow"))
+          case (Some(_), None) =>
+            fail("Exception should be thrown")
+          case (None, Some(cometExc)) =>
+            throw cometExc
+          case _ =>
+        }
+      }
+    }
+  }
+
+  private def createTestData(generateNegativeZero: Boolean) = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("c0", DataTypes.ByteType, nullable = true),
+        StructField("c1", DataTypes.ShortType, nullable = true),
+        StructField("c2", DataTypes.IntegerType, nullable = true),
+        StructField("c3", DataTypes.LongType, nullable = true),
+        StructField("c4", DataTypes.FloatType, nullable = true),
+        StructField("c5", DataTypes.DoubleType, nullable = true),
+        StructField("c6", DataTypes.createDecimalType(10, 2), nullable = true)))
+    FuzzDataGenerator.generateDataFrame(
+      r,
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = generateNegativeZero))
+  }
+}
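End to end, the behavior the new suite exercises can be reproduced interactively. A sketch, assuming Comet is enabled in the session (the TINYINT case mirrors the Byte overflow column from the fuzz data above):

```scala
// Legacy mode: abs of the most negative TINYINT wraps to itself, matching Spark.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT abs(CAST(-128 AS TINYINT))").show() // -128

// ANSI mode: both Spark and the native Comet abs raise ARITHMETIC_OVERFLOW
// (surfaced as org.apache.spark.SparkArithmeticException).
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT abs(CAST(-128 AS TINYINT))").collect() // throws
```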