diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 294922f2f1..c43230f49f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -66,6 +66,7 @@ use datafusion::{ }; use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}; use datafusion_functions_nested::concat::ArrayAppend; +use datafusion_functions_nested::remove::array_remove_all_udf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use crate::execution::shuffle::CompressionCodec; @@ -735,6 +736,35 @@ impl PhysicalPlanner { )); Ok(array_has_expr) } + ExprStruct::ArrayRemove(expr) => { + let src_array_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let key_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)]; + let return_type = src_array_expr.data_type(&input_schema)?; + + let datafusion_array_remove = array_remove_all_udf(); + + let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new( + "array_remove", + datafusion_array_remove, + args, + return_type, + )); + let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr)); + + let null_literal_expr: Arc<dyn PhysicalExpr> = + Arc::new(Literal::new(ScalarValue::Null)); + + let case_expr = CaseExpr::try_new( + None, + vec![(is_null_expr, null_literal_expr)], + Some(array_remove_expr), + )?; + + Ok(Arc::new(case_expr)) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index e76ecdccf1..8e3bc60b0f 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -85,6 +85,7 @@ message Expr { BinaryExpr array_append = 58; ArrayInsert array_insert = 59; BinaryExpr array_contains = 60; + BinaryExpr array_remove = 61; } } diff --git 
a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7a69e63096..8186486516 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(expr, "unsupported arguments for GetArrayStructFields", child) None } + case expr if expr.prettyName == "array_remove" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayRemove(binaryExpr)) case expr if expr.prettyName == "array_contains" => createBinaryExpr( expr.children(0), diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index afdf8601de..8c2759a384 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2529,4 +2529,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); } } + + test("array_remove") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null")) + checkSparkAnswerAndOperator( + sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null")) + checkSparkAnswerAndOperator(sql( + "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1")) + } + } + } }