From 919038222472a04b2f5480f2484400528dcb9cd0 Mon Sep 17 00:00:00 2001 From: Kazantsev Maksim Date: Sun, 12 Oct 2025 15:54:10 +0400 Subject: [PATCH] Refactor bit_count --- native/core/src/execution/jni_api.rs | 2 + .../src/bitwise_funcs/bitwise_count.rs | 148 ------------------ native/spark-expr/src/bitwise_funcs/mod.rs | 2 - native/spark-expr/src/comet_scalar_funcs.rs | 5 +- .../org/apache/comet/serde/bitwise.scala | 25 +-- 5 files changed, 6 insertions(+), 176 deletions(-) delete mode 100644 native/spark-expr/src/bitwise_funcs/bitwise_count.rs diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index b17cfa1d9b..131b37a182 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -40,6 +40,7 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; use datafusion_comet_proto::spark_operator::Operator; +use datafusion_spark::function::bitwise::bit_count::SparkBitCount; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; @@ -332,6 +333,7 @@ fn prepare_datafusion_session_context( session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitCount::default())); // Must be the last one to override existing functions with the same name datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?; diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs deleted file mode 100644 index 4ab63e532c..0000000000 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ /dev/null @@ -1,148 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::{array::*, datatypes::DataType}; -use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result}; -use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; -use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; -use std::any::Any; -use std::sync::Arc; - -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct SparkBitwiseCount { - signature: Signature, - aliases: Vec, -} - -impl Default for SparkBitwiseCount { - fn default() -> Self { - Self::new() - } -} - -impl SparkBitwiseCount { - pub fn new() -> Self { - Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], - } - } -} - -impl ScalarUDFImpl for SparkBitwiseCount { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "bit_count" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Int32) - } - - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args: [ColumnarValue; 1] = args - .args - .try_into() - .map_err(|_| internal_datafusion_err!("bit_count expects exactly one argument"))?; - spark_bit_count(args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -macro_rules! compute_op { - ($OPERAND:expr, $DT:ident) => {{ - let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| { - DataFusionError::Execution(format!( - "compute_op failed to downcast array to: {:?}", - stringify!($DT) - )) - })?; - - let result: Int32Array = operand - .iter() - .map(|x| x.map(|y| bit_count(y.into()))) - .collect(); - - Ok(Arc::new(result)) - }}; -} - -pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result { - match args { - [ColumnarValue::Array(array)] => { - let result: Result = match array.data_type() { - DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array), - DataType::Int16 => compute_op!(array, Int16Array), - DataType::Int32 => compute_op!(array, Int32Array), - DataType::Int64 => compute_op!(array, Int64Array), - _ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()), - }; - result.map(ColumnarValue::Array) - } - [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"), - } -} - -// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType) -fn bit_count(i: i64) -> i32 { - let mut u = i as u64; - u = u - ((u >> 1) & 0x5555555555555555); - u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333); - u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f; - u = u + (u >> 8); - u = u + (u >> 16); - u = u + (u >> 32); - (u as i32) & 0x7f -} - -#[cfg(test)] -mod tests { - use datafusion::common::{cast::as_int32_array, Result}; - - use super::*; - - #[test] - fn bitwise_count_op() -> Result<()> { - let args = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(12345), - Some(89), - Some(-3456), - ]))); - let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); - - let ColumnarValue::Array(result) = spark_bit_count([args])? else { - unreachable!() - }; - - let result = as_int32_array(&result).expect("failed to downcast to In32Array"); - assert_eq!(result, expected); - - Ok(()) - } -} diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs index 3f148a6dc7..848fbc620f 100644 --- a/native/spark-expr/src/bitwise_funcs/mod.rs +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -mod bitwise_count; mod bitwise_not; -pub use bitwise_count::SparkBitwiseCount; pub use bitwise_not::SparkBitwiseNot; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 19fa11e641..cc8eda2189 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -21,8 +21,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, - spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot, - SparkDateTrunc, SparkStringSpace, + spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseNot, SparkDateTrunc, + SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -191,7 +191,6 @@ pub fn create_comet_physical_fun_with_eval_mode( fn all_scalar_functions() -> Vec> { vec![ Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())), - Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), ] diff --git a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala index 8215ea5dfa..919b6dd043 100644 --- a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala +++ b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala @@ -127,27 +127,6 @@ object CometShiftLeft extends CometExpressionSerde[ShiftLeft] { } } -object CometBitwiseGet extends CometExpressionSerde[BitwiseGet] { - override def convert( - expr: BitwiseGet, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val argProto = exprToProto(expr.left, inputs, binding) - val posProto = exprToProto(expr.right, inputs, binding) - val bitGetScalarExpr = - scalarFunctionExprToProtoWithReturnType("bit_get", ByteType, argProto, posProto) - optExprWithInfo(bitGetScalarExpr, expr, expr.children: _*) - } -} +object CometBitwiseGet extends CometScalarFunction[BitwiseGet]("bit_get") -object CometBitwiseCount extends CometExpressionSerde[BitwiseCount] { - override def convert( - expr: BitwiseCount, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val childProto = exprToProto(expr.child, inputs, binding) - val bitCountScalarExpr = - scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto) - optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*) - } -} +object CometBitwiseCount extends CometScalarFunction[BitwiseCount]("bit_count")