From 58d8c9a352cf43e0127fb1e3518c624b82a4215d Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:32:23 -0700 Subject: [PATCH 01/23] Removed old ABS implementation --- native/core/src/execution/planner.rs | 13 ------------- native/proto/src/proto/expr.proto | 6 ------ 2 files changed, 19 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 1550efd799..ffa5f86052 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -675,19 +675,6 @@ impl PhysicalPlanner { let op = DataFusionOperator::BitwiseShiftLeft; Ok(Arc::new(BinaryExpr::new(left, op, right))) } - // https://github.com/apache/datafusion-comet/issues/666 - // ExprStruct::Abs(expr) => { - // let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; - // let return_type = child.data_type(&input_schema)?; - // let args = vec![child]; - // let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - // let comet_abs = Arc::new(ScalarUDF::new_from_impl(Abs::new( - // eval_mode, - // return_type.to_string(), - // )?)); - // let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type); - // Ok(Arc::new(expr)) - // } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when .when diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5853bc613c..c9037dcd69 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -70,7 +70,6 @@ message Expr { IfExpr if = 44; NormalizeNaNAndZero normalize_nan_and_zero = 45; TruncTimestamp truncTimestamp = 47; - Abs abs = 49; Subquery subquery = 50; UnboundReference unbound = 51; BloomFilterMightContain bloom_filter_might_contain = 52; @@ -351,11 +350,6 @@ message TruncTimestamp { string timezone = 3; } -message Abs { - Expr child = 1; - EvalMode eval_mode = 2; -} - message Subquery { int64 id = 1; DataType datatype = 2; From d5b70c3275fbe958b3b7a88a35fd78bce49f1bba Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:33:28 -0700 Subject: [PATCH 02/23] Define Comet's ABS in Scala --- .../apache/comet/serde/QueryPlanSerde.scala | 3 ++- .../scala/org/apache/comet/serde/math.scala | 21 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 233261091b..33739d9491 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -137,7 +137,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Subtract] -> CometSubtract, classOf[Tan] -> CometScalarFunction("tan"), classOf[UnaryMinus] -> CometUnaryMinus, - classOf[Unhex] -> CometUnhex) + classOf[Unhex] -> CometUnhex, + classOf[Abs] -> CometAbs) private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[GetMapValue] -> CometMapExtract, diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index bfcd242d76..90de894e3e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} 
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} import org.apache.spark.sql.types.DecimalType import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -144,6 +144,25 @@ object CometUnhex extends CometExpressionSerde[Unhex] with MathExprBase { } } +object CometAbs extends CometExpressionSerde[Abs] with MathExprBase { + override def convert( + expr: Abs, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val failOnErrorExpr = exprToProtoInternal(Literal(expr.failOnError), inputs, binding) + + val optExpr = + scalarFunctionExprToProtoWithReturnType( + "abs", + expr.dataType, + false, + childExpr, + failOnErrorExpr) + optExprWithInfo(optExpr, expr, expr.child) + } +} + sealed trait MathExprBase { protected def nullIfNegative(expression: Expression): Expression = { val zero = Literal.default(expression.dataType) From 7b1965adb8777974d251809e5451065cce888acf Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:34:32 -0700 Subject: [PATCH 03/23] Implement Comet's ABS in rust --- native/spark-expr/src/comet_scalar_funcs.rs | 5 + native/spark-expr/src/math_funcs/abs.rs | 798 ++++++++++++++++++++ native/spark-expr/src/math_funcs/mod.rs | 1 + 3 files changed, 804 insertions(+) create mode 100644 native/spark-expr/src/math_funcs/abs.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index fc0c096b15..021bb1c78f 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. use crate::hash_funcs::*; +use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ @@ -180,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_modulo); make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error) } + "abs" => { + let func = Arc::new(abs); + make_comet_scalar_udf!("abs", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs new file mode 100644 index 0000000000..54d21a3bb5 --- /dev/null +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -0,0 +1,798 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::arithmetic_overflow_error; +use arrow::array::*; +use arrow::datatypes::*; +use arrow::error::ArrowError; +use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +macro_rules! legacy_compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); + Ok(res) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid data type for abs" + ))), + } + }}; +} + +macro_rules! ansi_compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + match arrow::compute::kernels::arity::try_unary(array, |x| { + if x == $NATIVE::MIN { + Err(ArrowError::ArithmeticOverflow($FROM_TYPE.to_string())) + } else { + Ok(x.$FUNC()) + } + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::>::new( + res, + ))), + Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()), + } + } + _ => Err(DataFusionError::Internal("Invalid data type".to_string())), + } + }}; +} + +/// This function mimics SparkSQL's [Abs]: https://github.com/apache/spark/blob/v4.0.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala#L148 +/// Spark's [ANSI-compliant]: https://spark.apache.org/docs/latest/sql-ref-ansi-compliance.html#arithmetic-operations dialect mode throws org.apache.spark.SparkArithmeticException +/// when abs causes overflow. +pub fn abs(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("abs takes at most 2 arguments, but got: {}", args.len()); + } + + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(args[0].clone()), + DataType::Int8 => { + if !fail_on_error { + let result = legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int8Array, Int8Type, i8, "Int8") + } + } + DataType::Int16 => { + if !fail_on_error { + let result = legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int16Array, Int16Type, i16, "Int16") + } + } + DataType::Int32 => { + if !fail_on_error { + let result = legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int32Array, Int32Type, i32, "Int32") + } + } + DataType::Int64 => { + if !fail_on_error { + let result = legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int64Array, Int64Type, i64, "Int64") + } + } + DataType::Float32 => { + let result = legacy_compute_op!(array, abs, Float32Array, Float32Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } + DataType::Float64 => { + let result = legacy_compute_op!(array, abs, Float64Array, 
Float64Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } + DataType::Decimal128(precision, scale) => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Decimal128Array, Decimal128Array)?; + let result = result.with_data_type(DataType::Decimal128(*precision, *scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + // Need to pass precision and scale from input, so not using ansi_compute_op + let input = array.as_any().downcast_ref::(); + match input { + Some(i) => { + match arrow::compute::kernels::arity::try_unary(i, |x| { + if x == i128::MIN { + Err(ArrowError::ArithmeticOverflow("Decimal128".to_string())) + } else { + Ok(x.abs()) + } + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::< + PrimitiveArray, + >::new( + res.with_data_type(DataType::Decimal128(*precision, *scale)), + ))), + Err(_) => Err(arithmetic_overflow_error("Decimal128").into()), + } + } + _ => Err(DataFusionError::Internal("Invalid data type".to_string())), + } + } + } + DataType::Decimal256(precision, scale) => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Decimal256Array, Decimal256Array)?; + let result = result.with_data_type(DataType::Decimal256(*precision, *scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + // Need to pass precision and scale from input, so not using ansi_compute_op + let input = array.as_any().downcast_ref::(); + match input { + Some(i) => { + match arrow::compute::kernels::arity::try_unary(i, |x| { + if x == i256::MIN { + Err(ArrowError::ArithmeticOverflow("Decimal256".to_string())) + } else { + Ok(x.wrapping_abs()) // i256 doesn't define abs() method + } + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::< + PrimitiveArray, + >::new( + res.with_data_type(DataType::Decimal256(*precision, *scale)), + ))), + Err(_) => Err(arithmetic_overflow_error("Decimal256").into()), + } + } + _ => Err(DataFusionError::Internal("Invalid data type".to_string())), + } + } + } + dt => exec_err!("Not supported datatype for ABS: {dt}"), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) => Ok(args[0].clone()), + ScalarValue::Int8(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int8").into()) + } + } + }) + .unwrap(), + ScalarValue::Int16(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int16").into()) + } + } + }) + .unwrap(), + ScalarValue::Int32(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int32").into()) + } + } + }) + .unwrap(), + ScalarValue::Int64(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))), + None => { + if !fail_on_error { + // return the original value + 
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int64").into()) + } + } + }) + .unwrap(), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float32( + a.map(|x| x.abs()), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float64( + a.map(|x| x.abs()), + ))), + ScalarValue::Decimal128(a, precision, scale) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(abs_val), + *precision, + *scale, + ))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(v), + *precision, + *scale, + ))) + } else { + Err(arithmetic_overflow_error("Decimal128").into()) + } + } + }) + .unwrap(), + ScalarValue::Decimal256(a, precision, scale) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(abs_val), + *precision, + *scale, + ))), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(v), + *precision, + *scale, + ))) + } else { + Err(arithmetic_overflow_error("Decimal256").into()) + } + } + }) + .unwrap(), + dt => exec_err!("Not supported datatype for ABS: {dt}"), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::cast::{ + as_decimal128_array, as_decimal256_array, as_float32_array, as_float64_array, + as_int16_array, as_int32_array, as_int64_array, as_int8_array, as_uint64_array, + }; + + fn with_fail_on_error Result<()>>(test_fn: F) { + for fail_on_error in [true, false] { + let _ = test_fn(fail_on_error); + } + } + + // Unsigned types, return as is + #[test] + fn test_abs_u8_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => { + assert_eq!(result, u8::MAX); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => { + assert_eq!(result, i8::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => { + assert_eq!(result, i16::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => { + assert_eq!(result, i32::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => { + assert_eq!(result, i64::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), precision, scale))) => { + assert_eq!(result, i128::MIN); + assert_eq!(precision, 18); + assert_eq!(scale, 10); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(Some(result), precision, scale))) => { + assert_eq!(result, i256::MIN); + assert_eq!(precision, 10); + assert_eq!(scale, 2); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_array() { + with_fail_on_error(|fail_on_error| { + let input = Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int8_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_array() { + with_fail_on_error(|fail_on_error| { + let input = Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int16_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_array() { + with_fail_on_error(|fail_on_error| { + let input = Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_array() { + with_fail_on_error(|fail_on_error| { + let input = Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f32_array() { + with_fail_on_error(|fail_on_error| { + let input = Float32Array::from(vec![Some(-1f32), Some(f32::MIN), Some(f32::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Float32Array::from(vec![Some(1f32), Some(f32::MAX), Some(f32::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f64_array() { + with_fail_on_error(|fail_on_error| { + let input = Float64Array::from(vec![Some(-1f64), Some(f64::MIN), Some(f64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Float64Array::from(vec![Some(1f64), Some(f64::MAX), Some(f64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal128_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal256_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_u64_array() { + with_fail_on_error(|fail_on_error| { + let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_uint64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } +} diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 873b290ebd..7df87eb9f2 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub(crate) mod abs; mod ceil; pub(crate) mod checked_arithmetic; mod div; From eb94cb6027a994e669ca4b517f0c98db999d13cd Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:34:49 -0700 Subject: [PATCH 04/23] Enable ABS tests in legacy/ANSI mode --- .../apache/comet/CometExpressionSuite.scala | 27 ++++++++++--------- .../org/apache/spark/sql/CometTestBase.scala | 8 +++--- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 9085c0fa29..f64b6d22b5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1385,23 +1385,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { testDoubleScalarExpr("expm1") } - // https://github.com/apache/datafusion-comet/issues/666 - ignore("abs") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - withParquetTable(path.toString, "tbl") { - Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col => - checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") + test("abs") { + Seq(true, false).foreach { ansi_enabled => + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi_enabled.toString) { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100) + withParquetTable(path.toString, "tbl") { + Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col => + checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") + } + } } } } } } - // https://github.com/apache/datafusion-comet/issues/666 - ignore("abs Overflow ansi mode") { + test("abs Overflow ANSI mode") { def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetTable(data, "tbl") { @@ -1434,8 +1436,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - // 
https://github.com/apache/datafusion-comet/issues/666 - ignore("abs Overflow legacy mode") { + test("abs Overflow legacy mode") { def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 844bd07f3b..900b8a44f5 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -504,10 +504,10 @@ abstract class CometTestBase | optional float _6; | optional double _7; | optional binary _8(UTF8); - | optional int32 _9(UINT_8); - | optional int32 _10(UINT_16); - | optional int32 _11(UINT_32); - | optional int64 _12(UINT_64); + | optional int32 _9(INT_8); + | optional int32 _10(INT_16); + | optional int32 _11(INT_32); + | optional int64 _12(INT_64); | optional binary _13(ENUM); | optional FIXED_LEN_BYTE_ARRAY(3) _14; | optional int32 _15(DECIMAL(5, 2)); From ce1c1c20684d2a4a0386e8bc0686eca59392d2f2 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Wed, 15 Oct 2025 17:02:14 -0700 Subject: [PATCH 05/23] Fix bit position b/c schema change in CometTestBase --- .../scala/org/apache/comet/CometBitwiseExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala index d89e81b0fd..cf7eb02bff 100644 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala @@ -134,7 +134,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe s"bit_get(_3, $shortBitPosition)", s"bit_get(_4, $intBitPosition)", s"bit_get(_5, $longBitPosition)", - s"bit_get(_11, $longBitPosition)")) + s"bit_get(_11, $intBitPosition)")) } } } From 9cb2f6ca208699958459f0cce30a77deecd1cd95 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Thu, 16 Oct 2025 14:21:58 -0700 Subject: [PATCH 06/23] Update docs --- docs/source/user-guide/latest/configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index a299a75738..1b593e4821 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -150,6 +150,7 @@ These settings can be used to determine which parts of the plan are accelerated | Config | Description | Default Value | |--------|-------------|---------------| +| `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true | | `spark.comet.expression.Acos.enabled` | Enable Comet acceleration for `Acos` | true | | `spark.comet.expression.Add.enabled` | Enable Comet acceleration for `Add` | true | | `spark.comet.expression.Alias.enabled` | Enable Comet acceleration for `Alias` | true | From 6016f6993ee151ba8d1fe336b096422f8ce552f8 Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Fri, 17 Oct 2025 13:32:29 -0700 Subject: [PATCH 07/23] Fix style --- native/spark-expr/src/math_funcs/abs.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 54d21a3bb5..78148995dd 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -460,7 +460,11 @@ mod tests { let 
fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); match abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), precision, scale))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(result), + precision, + scale, + ))) => { assert_eq!(result, i128::MIN); assert_eq!(precision, 18); assert_eq!(scale, 10); @@ -489,7 +493,11 @@ mod tests { let fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); match abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(Some(result), precision, scale))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(result), + precision, + scale, + ))) => { assert_eq!(result, i256::MIN); assert_eq!(precision, 10); assert_eq!(scale, 2); From 1ddff9831c6e0b750289723e24e89a08b7c6a39e Mon Sep 17 00:00:00 2001 From: hsiang-c Date: Fri, 17 Oct 2025 13:48:23 -0700 Subject: [PATCH 08/23] Update doc --- docs/source/user-guide/latest/expressions.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 3ccead03a1..809e69d2f8 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -118,6 +118,7 @@ incompatible expressions. | Expression | SQL | Spark-Compatible? | Compatibility Notes | |----------------|-----------|-------------------|-----------------------------------| +| Abs | `abs` | Yes | | | Acos | `acos` | Yes | | | Add | `+` | Yes | | | Asin | `asin` | Yes | | From 3a48453259db08752a9fdb0557704ad6d287d0d1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Nov 2025 19:09:33 -0700 Subject: [PATCH 09/23] new test --- .../scala/org/apache/comet/serde/math.scala | 14 +++- .../apache/comet/CometExpressionSuite.scala | 69 ---------------- .../comet/CometMathExpressionSuite.scala | 80 +++++++++++++++++++ .../org/apache/spark/sql/CometTestBase.scala | 8 +- 4 files changed, 96 insertions(+), 75 deletions(-) create mode 100644 spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index 90de894e3e..36f0af7769 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -20,8 +20,7 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} -import org.apache.spark.sql.types.DecimalType - +import org.apache.spark.sql.types.{DecimalType, NumericType, ShortType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType} @@ -145,6 +144,17 @@ object CometUnhex extends CometExpressionSerde[Unhex] with MathExprBase { } object CometAbs extends CometExpressionSerde[Abs] with MathExprBase { + + override def getSupportLevel(expr: Abs): SupportLevel = { + expr.child.dataType match { + case _: NumericType => + Compatible() + case _ => + // Spark supports NumericType, DayTimeIntervalType, and YearMonthIntervalType + Unsupported(Some("Only integral, floating-point, and decimal types are supported")) + } + } + override def convert( expr: Abs, inputs: Seq[Attribute], 
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e43ddb7da2..1262091d39 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1430,75 +1430,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { testDoubleScalarExpr("expm1") } - test("abs") { - Seq(true, false).foreach { ansi_enabled => - Seq(true, false).foreach { dictionaryEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi_enabled.toString) { - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - withParquetTable(path.toString, "tbl") { - Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col => - checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") - } - } - } - } - } - } - } - - test("abs Overflow ANSI mode") { - - def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetTable(data, "tbl") { - checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { - case (Some(sparkExc), Some(cometExc)) => - val cometErrorPattern = - """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r - assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined) - assert(sparkExc.getMessage.contains("overflow")) - case _ => fail("Exception should be thrown") - } - } - } - - def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") - } - } - - withSQLConf( - SQLConf.ANSI_ENABLED.key -> "true", - CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { - testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue))) - testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue))) - testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue))) - testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue))) - testAbsAnsi(Seq((Float.MaxValue, Float.MinValue))) - testAbsAnsi(Seq((Double.MaxValue, Double.MinValue))) - } - } - - test("abs Overflow legacy mode") { - - def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") - } - } - } - - testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue))) - testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue))) - testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue))) - testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue))) - testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue))) - testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue))) - } - test("ceil and floor") { Seq("true", "false").foreach { dictionary => withSQLConf( diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala new file mode 100644 index 0000000000..e78de58e0d --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import scala.util.Random +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} +import org.apache.spark.sql.internal.SQLConf + +class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + test("abs") { + val df = createTestData + df.createOrReplaceTempView("tbl") + for (field <- df.schema.fields) { + val col = field.name + checkSparkAnswerAndOperator(s"SELECT $col, abs($col) FROM tbl ORDER BY $col") + } + } + + test("abs (ANSI mode)") { + val df = createTestData + df.createOrReplaceTempView("tbl") + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + for (field <- df.schema.fields) { + val col = field.name + checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY $col")) match { + case (Some(sparkExc), Some(cometExc)) => + val cometErrorPattern = + """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r + assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined) + assert(sparkExc.getMessage.contains("overflow")) + case (Some(_), None) => + fail("Exception should be thrown") + case (None, Some(cometExc)) => + throw cometExc + case _ => + } + } + } + } + + private def createTestData = { + val r = new Random(42) + val schema = StructType( + Seq( + StructField("c0", DataTypes.ByteType, nullable = true), + StructField("c1", DataTypes.ShortType, nullable = true), + StructField("c2", DataTypes.IntegerType, nullable = true), + StructField("c3", DataTypes.LongType, nullable = true), + StructField("c4", DataTypes.FloatType, nullable = true), + StructField("c5", DataTypes.DoubleType, nullable = true), + StructField("c6", DataTypes.createDecimalType(10, 2), nullable = true))) + FuzzDataGenerator.generateDataFrame( + r, + spark, + schema, + 1000, + DataGenOptions(generateNegativeZero = false)) + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index bd335e407f..1854edf590 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -621,10 +621,10 @@ abstract class CometTestBase | optional float _6; | optional double _7; | optional binary _8(UTF8); - | optional int32 _9(INT_8); - | optional int32 _10(INT_16); - | optional int32 _11(INT_32); - | optional int64 _12(INT_64); + | optional int32 _9(UINT_8); + | optional int32 _10(UINT_16); + | optional int32 _11(UINT_32); + | optional int64 _12(UINT_64); | optional binary _13(ENUM); | optional FIXED_LEN_BYTE_ARRAY(3) _14; | optional int32 _15(DECIMAL(5, 2)); From 2351dfb3dbe513c887a8063d863aa99ab6dc80fa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 
4 Nov 2025 19:10:33 -0700 Subject: [PATCH 10/23] format --- spark/src/main/scala/org/apache/comet/serde/math.scala | 3 ++- .../test/scala/org/apache/comet/CometExpressionSuite.scala | 2 -- .../scala/org/apache/comet/CometMathExpressionSuite.scala | 4 +++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index 36f0af7769..68b6e8d11e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -20,7 +20,8 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex} -import org.apache.spark.sql.types.{DecimalType, NumericType, ShortType} +import org.apache.spark.sql.types.{DecimalType, NumericType} + import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 1262091d39..d502749380 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -21,8 +21,6 @@ package org.apache.comet import java.time.{Duration, Period} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.scalactic.source.Position diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala index e78de58e0d..6308e77023 100644 --- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala @@ -20,11 +20,13 @@ package org.apache.comet import scala.util.Random + import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} + import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} -import org.apache.spark.sql.internal.SQLConf class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { From 4952286e6726aeaa9c2f031381847359200bcbca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Nov 2025 19:11:42 -0700 Subject: [PATCH 11/23] Revert --- .../scala/org/apache/comet/CometBitwiseExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala index 0e310da928..02c003ede8 100644 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala @@ -130,7 +130,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe s"bit_get(_3, $shortBitPosition)", s"bit_get(_4, $intBitPosition)", s"bit_get(_5, $longBitPosition)", - s"bit_get(_11, $intBitPosition)")) + s"bit_get(_11, $longBitPosition)")) } } } From c414db98b57b8d1f8ba61640275a40105825f75c Mon Sep 17 00:00:00 2001 From: Andy 
Grove Date: Tue, 4 Nov 2025 19:28:46 -0700 Subject: [PATCH 12/23] negative zero test --- .../comet/CometMathExpressionSuite.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala index 6308e77023..e6ba668c89 100644 --- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala @@ -31,7 +31,7 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("abs") { - val df = createTestData + val df = createTestData(generateNegativeZero = false) df.createOrReplaceTempView("tbl") for (field <- df.schema.fields) { val col = field.name @@ -39,8 +39,19 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe } } + test("abs - negative zero") { + val df = createTestData(generateNegativeZero = true) + df.createOrReplaceTempView("tbl") + for (field <- df.schema.fields.filter(f => + f.dataType == DataTypes.FloatType || f.dataType == DataTypes.DoubleType)) { + val col = field.name + checkSparkAnswerAndOperator( + s"SELECT $col, abs($col) FROM tbl WHERE signum($col) < 0 ORDER BY $col") + } + } + test("abs (ANSI mode)") { - val df = createTestData + val df = createTestData(generateNegativeZero = false) df.createOrReplaceTempView("tbl") withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { for (field <- df.schema.fields) { @@ -61,7 +72,7 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe } } - private def createTestData = { + private def createTestData(generateNegativeZero: Boolean) = { val r = new Random(42) val schema = StructType( Seq( @@ -77,6 +88,6 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe spark, schema, 1000, - DataGenOptions(generateNegativeZero = false)) + DataGenOptions(generateNegativeZero)) } } From 4236ddeab7cfa79df272651fac9079c9fb1c606d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Nov 2025 19:40:14 -0700 Subject: [PATCH 13/23] fix null scalar issue --- native/spark-expr/src/math_funcs/abs.rs | 133 ++++++++++++++++++------ 1 file changed, 103 insertions(+), 30 deletions(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 78148995dd..9313b50fe1 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -198,66 +198,71 @@ pub fn abs(args: &[ColumnarValue]) -> Result { | ScalarValue::UInt16(_) | ScalarValue::UInt32(_) | ScalarValue::UInt64(_) => Ok(args[0].clone()), - ScalarValue::Int8(a) => a - .map(|v| match v.checked_abs() { + ScalarValue::Int8(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))), None => { if !fail_on_error { // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(v)))) + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(*v)))) } else { Err(arithmetic_overflow_error("Int8").into()) } } - }) - .unwrap(), - ScalarValue::Int16(a) => a - .map(|v| match v.checked_abs() { + }, + }, + ScalarValue::Int16(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))), None => { if !fail_on_error { // 
return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(v)))) + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(*v)))) } else { Err(arithmetic_overflow_error("Int16").into()) } } - }) - .unwrap(), - ScalarValue::Int32(a) => a - .map(|v| match v.checked_abs() { + }, + }, + ScalarValue::Int32(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))), None => { if !fail_on_error { // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(*v)))) } else { Err(arithmetic_overflow_error("Int32").into()) } } - }) - .unwrap(), - ScalarValue::Int64(a) => a - .map(|v| match v.checked_abs() { + }, + }, + ScalarValue::Int64(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))), None => { if !fail_on_error { // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(*v)))) } else { Err(arithmetic_overflow_error("Int64").into()) } } - }) - .unwrap(), + }, + }, ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float32( a.map(|x| x.abs()), ))), ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float64( a.map(|x| x.abs()), ))), - ScalarValue::Decimal128(a, precision, scale) => a - .map(|v| match v.checked_abs() { + ScalarValue::Decimal128(a, precision, scale) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( Some(abs_val), *precision, @@ -267,7 +272,7 @@ pub fn abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { // return the original value Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(v), + Some(*v), *precision, *scale, ))) @@ -275,10 +280,11 @@ pub fn abs(args: &[ColumnarValue]) -> Result { Err(arithmetic_overflow_error("Decimal128").into()) } } - }) - .unwrap(), - ScalarValue::Decimal256(a, precision, scale) => a - .map(|v| match v.checked_abs() { + }, + }, + ScalarValue::Decimal256(a, precision, scale) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( Some(abs_val), *precision, @@ -288,7 +294,7 @@ pub fn abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { // return the original value Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( - Some(v), + Some(*v), *precision, *scale, ))) @@ -296,8 +302,8 @@ pub fn abs(args: &[ColumnarValue]) -> Result { Err(arithmetic_overflow_error("Decimal256").into()) } } - }) - .unwrap(), + }, + }, dt => exec_err!("Not supported datatype for ABS: {dt}"), }, } @@ -803,4 +809,71 @@ mod tests { } }); } + + #[test] + fn test_abs_null_scalars() { + // Test that NULL scalars return NULL (no panic) for all signed types + with_fail_on_error(|fail_on_error| { + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + // Test Int8 + let args = ColumnarValue::Scalar(ScalarValue::Int8(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(None))) => {} + _ => panic!("Expected NULL Int8, got different result"), + } + + // Test Int16 + let args = ColumnarValue::Scalar(ScalarValue::Int16(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + 
Ok(ColumnarValue::Scalar(ScalarValue::Int16(None))) => {} + _ => panic!("Expected NULL Int16, got different result"), + } + + // Test Int32 + let args = ColumnarValue::Scalar(ScalarValue::Int32(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) => {} + _ => panic!("Expected NULL Int32, got different result"), + } + + // Test Int64 + let args = ColumnarValue::Scalar(ScalarValue::Int64(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))) => {} + _ => panic!("Expected NULL Int64, got different result"), + } + + // Test Decimal128 + let args = ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2))) => {} + _ => panic!("Expected NULL Decimal128, got different result"), + } + + // Test Decimal256 + let args = ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2))) => {} + _ => panic!("Expected NULL Decimal256, got different result"), + } + + // Test Float32 + let args = ColumnarValue::Scalar(ScalarValue::Float32(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Float32(None))) => {} + _ => panic!("Expected NULL Float32, got different result"), + } + + // Test Float64 + let args = ColumnarValue::Scalar(ScalarValue::Float64(None)); + match abs(&[args.clone(), fail_on_error_arg.clone()]) { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) => {} + _ => panic!("Expected NULL Float64, got different result"), + } + + Ok(()) + }); + } } From 0c1b1728ab80cba4e01f71a4e418fa8ddc67f938 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Nov 2025 19:47:29 -0700 Subject: [PATCH 14/23] add new test to CI --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 2867f61da7..02b544e2d5 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -134,6 +134,7 @@ jobs: org.apache.comet.CometCastSuite org.apache.comet.CometExpressionSuite org.apache.comet.CometExpressionCoverageSuite + org.apache.comet.CometMathExpressionSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.comet.CometStringExpressionSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 0fd1cb6066..3a1b82d044 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -99,6 +99,7 @@ jobs: org.apache.comet.CometCastSuite org.apache.comet.CometExpressionSuite org.apache.comet.CometExpressionCoverageSuite + org.apache.comet.CometMathExpressionSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.comet.CometStringExpressionSuite From 7fcfdd3ffefb17d9b4d06447d32c9fa121d1f177 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Nov 2025 20:53:09 -0700 Subject: [PATCH 15/23] fix --- .../test/scala/org/apache/comet/CometMathExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
index e6ba668c89..9428dae72f 100644
--- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -88,6 +88,6 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe
       spark,
       schema,
       1000,
-      DataGenOptions(generateNegativeZero))
+      DataGenOptions(generateNegativeZero = generateNegativeZero))
   }
 }

From abddd9eb8b285b45ca8ef920648c21094ec70ba1 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 5 Nov 2025 08:02:02 -0700
Subject: [PATCH 16/23] fix bug in negative zero test

---
 .../test/scala/org/apache/comet/CometMathExpressionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
index 9428dae72f..c95047a0ef 100644
--- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -46,7 +46,7 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe
           f.dataType == DataTypes.FloatType || f.dataType == DataTypes.DoubleType)) {
         val col = field.name
         checkSparkAnswerAndOperator(
-          s"SELECT $col, abs($col) FROM tbl WHERE signum($col) < 0 ORDER BY $col")
+          s"SELECT $col, abs($col) FROM tbl WHERE CAST($col as string) = '-0.0' ORDER BY $col")
       }
     }

From 3089dca4fb2f1c579ff3fe8f92d6db3986654c4c Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 5 Nov 2025 08:13:28 -0700
Subject: [PATCH 17/23] address review feedback

---
 native/spark-expr/src/math_funcs/abs.rs | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs
index 9313b50fe1..11e605cd18 100644
--- a/native/spark-expr/src/math_funcs/abs.rs
+++ b/native/spark-expr/src/math_funcs/abs.rs
@@ -65,8 +65,8 @@ macro_rules! ansi_compute_op {
 /// Spark's [ANSI-compliant]: https://spark.apache.org/docs/latest/sql-ref-ansi-compliance.html#arithmetic-operations dialect mode throws org.apache.spark.SparkArithmeticException
 /// when abs causes overflow.
 pub fn abs(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    if args.len() > 2 {
-        return exec_err!("abs takes at most 2 arguments, but got: {}", args.len());
+    if args.is_empty() || args.len() > 2 {
+        return exec_err!("abs takes 1 or 2 arguments, but got: {}", args.len());
     }
 
     let fail_on_error = if args.len() == 2 {
@@ -336,15 +336,7 @@
                 Ok(())
             }
             Err(e) => {
-                if fail_on_error {
-                    assert!(
-                        e.to_string().contains("ARITHMETIC_OVERFLOW"),
-                        "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + panic!("Didn't expect error, but got: {e:?}") } _ => unreachable!(), } From f0e9c6e7df683e4406db55e407759eefeb54f66e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 08:15:26 -0700 Subject: [PATCH 18/23] address feedback --- native/spark-expr/src/math_funcs/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 11e605cd18..250a077e96 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -319,7 +319,7 @@ mod tests { fn with_fail_on_error Result<()>>(test_fn: F) { for fail_on_error in [true, false] { - let _ = test_fn(fail_on_error); + test_fn(fail_on_error).expect("test should pass on error successfully"); } } From 50f899fcbcc22fb1c434b7ea802ec4c4d44b2c76 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 12:40:08 -0700 Subject: [PATCH 19/23] Update native/spark-expr/src/math_funcs/abs.rs Co-authored-by: Oleks V --- native/spark-expr/src/math_funcs/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 250a077e96..2573c7ad2d 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -389,7 +389,7 @@ mod tests { ); Ok(()) } else { - panic!("Didn't expect error, but got: {e:?}") + unreachable!("Didn't expect error, but got: {e:?}") } } _ => unreachable!(), From 25f86ff7bec38cdbfa931d0bdd04672f7f51c77a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 12:42:18 -0700 Subject: [PATCH 20/23] Update native/spark-expr/src/math_funcs/abs.rs Co-authored-by: Oleks V --- native/spark-expr/src/math_funcs/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 2573c7ad2d..1d0ce1632a 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -336,7 +336,7 @@ mod tests { Ok(()) } Err(e) => { - panic!("Didn't expect error, but got: {e:?}") + unreachable!("Didn't expect error, but got: {e:?}") } _ => unreachable!(), } From 437cfdcfc225c7550a9722e2a0ac36ef68f441e8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 12:42:30 -0700 Subject: [PATCH 21/23] Update native/spark-expr/src/math_funcs/abs.rs Co-authored-by: Oleks V --- native/spark-expr/src/math_funcs/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 1d0ce1632a..64cb2f13cb 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -362,7 +362,7 @@ mod tests { ); Ok(()) } else { - panic!("Didn't expect error, but got: {e:?}") + unreachable!("Didn't expect error, but got: {e:?}") } } _ => unreachable!(), From a596fc57c1bc8749c9cfdb83f8732706ab94d2d1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 12:51:40 -0700 Subject: [PATCH 22/23] address feedback --- native/spark-expr/src/math_funcs/abs.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 64cb2f13cb..67cf31f4de 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ 
-644,10 +644,25 @@ mod tests { #[test] fn test_abs_f32_array() { with_fail_on_error(|fail_on_error| { - let input = Float32Array::from(vec![Some(-1f32), Some(f32::MIN), Some(f32::MAX), None]); + let input = Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::NEG_INFINITY), + Some(f32::INFINITY), + ]); let args = ColumnarValue::Array(Arc::new(input)); - let expected = - Float32Array::from(vec![Some(1f32), Some(f32::MAX), Some(f32::MAX), None]); + let expected = Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + ]); let fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); From a8c1d0a7ca6cd7b11245493ee5d82fa42dc99c44 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 Nov 2025 13:54:45 -0700 Subject: [PATCH 23/23] add negative zero in native code --- native/spark-expr/src/math_funcs/abs.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs index 67cf31f4de..5a16398ec4 100644 --- a/native/spark-expr/src/math_funcs/abs.rs +++ b/native/spark-expr/src/math_funcs/abs.rs @@ -652,6 +652,8 @@ mod tests { Some(f32::NAN), Some(f32::NEG_INFINITY), Some(f32::INFINITY), + Some(-0.0), + Some(0.0), ]); let args = ColumnarValue::Array(Arc::new(input)); let expected = Float32Array::from(vec![ @@ -662,6 +664,8 @@ mod tests { Some(f32::NAN), Some(f32::INFINITY), Some(f32::INFINITY), + Some(0.0), + Some(0.0), ]); let fail_on_error_arg = ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
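
Taken together, the patches above settle on Spark-compatible ABS semantics in the native kernel: integer and decimal overflow (e.g. abs of i32::MIN) returns the input unchanged in legacy mode and raises an ARITHMETIC_OVERFLOW error in ANSI mode, while floating-point abs maps -0.0 to 0.0 and passes NaN and the infinities through. A minimal standalone sketch of those rules follows; it is not part of the patch series, and the helper name spark_abs_i32 is hypothetical.

    // Sketch of the overflow and special-value semantics exercised by the
    // tests above. Only `checked_abs` and `f32::abs` are std APIs; the rest
    // is illustrative.
    fn spark_abs_i32(v: i32, fail_on_error: bool) -> Result<i32, String> {
        match v.checked_abs() {
            Some(abs_val) => Ok(abs_val),
            // legacy (non-ANSI) Spark: overflow returns the input unchanged
            None if !fail_on_error => Ok(v),
            // ANSI mode surfaces an arithmetic overflow error instead
            None => Err("ARITHMETIC_OVERFLOW: abs(Int32)".to_string()),
        }
    }

    fn main() {
        assert_eq!(spark_abs_i32(-5, false), Ok(5));
        assert_eq!(spark_abs_i32(i32::MIN, false), Ok(i32::MIN)); // legacy: no error
        assert!(spark_abs_i32(i32::MIN, true).is_err()); // ANSI: overflow error
        assert!((-0.0f32).abs().is_sign_positive()); // abs(-0.0) == +0.0
        assert!(f32::NAN.abs().is_nan()); // NaN passes through
    }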