From 1181c146830ad3ba0c3b775d64a3c273ede4f3ed Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Sun, 29 Dec 2024 13:18:01 +0530 Subject: [PATCH 01/10] wip: implemention of array_repeat is done. Test cases are pending --- native/core/src/execution/planner.rs | 30 +++++++++++++++++++ native/proto/src/proto/expr.proto | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 6 ++++ .../apache/comet/CometExpressionSuite.scala | 22 ++++++++++++++ 4 files changed, 59 insertions(+) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5a35c62e33..27188611f7 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -98,6 +98,7 @@ use datafusion_expr::{ AggregateUDF, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_nested::repeat::array_repeat_udf; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::LexOrdering; @@ -719,6 +720,35 @@ impl PhysicalPlanner { expr.legacy_negative_index, ))) } + ExprStruct::ArrayRepeat(expr) => { + let value = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let count = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + println!("value {:?}", value); + println!("count {:?}", count); + let return_type = value.data_type(&input_schema)?; + let args = vec![Arc::clone(&value), count.clone()]; + + let datafusion_array_repeat = array_repeat_udf(); + let array_repeat_expr: Arc = Arc::new(ScalarFunctionExpr::new( + "array_repat", + datafusion_array_repeat, + args, + return_type, + )); + + let is_null_expr: Arc = Arc::new(IsNullExpr::new(count)); + let null_literal_expr: Arc = + Arc::new(Literal::new(ScalarValue::Null)); + + let case_expr = CaseExpr::try_new( + None, + vec![(is_null_expr, null_literal_expr)], + Some(array_repeat_expr), + )?; + Ok(Arc::new(case_expr)) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 7a8ea78d57..823dac8cee 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -84,6 +84,7 @@ message Expr { GetArrayStructFields get_array_struct_fields = 57; BinaryExpr array_append = 58; ArrayInsert array_insert = 59; + BinaryExpr array_repeat = 60; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 518fa06858..8a7f03b46a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(expr, "unsupported arguments for GetArrayStructFields", child) None } + case expr if expr.prettyName == "array_repeat" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayRepeat(binaryExpr)) case _ if expr.prettyName == "array_append" => createBinaryExpr( expr.children(0), diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cce7cb20a1..f8afd45e39 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2517,4 +2517,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswer(df.select("arrUnsupportedArgs")) } } + test("array_repeat") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 3) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(5, _2) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, null) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(2, null) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(null, 3) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(null, _3) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 0) from t1")) +// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, -1) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(8, 5) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_8, 2) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(true, 3) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(1.5, 2) from t1")) + } + } + } } From 5d9817a801ea6c6090eb68a76f529d1f5a49aca3 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Sun, 29 Dec 2024 14:08:57 +0530 Subject: [PATCH 02/10] fixed the clone issue for args --- native/core/src/execution/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 27188611f7..635a01e57c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -728,7 +728,7 @@ impl PhysicalPlanner { println!("value {:?}", value); println!("count {:?}", count); let return_type = value.data_type(&input_schema)?; - let args = vec![Arc::clone(&value), count.clone()]; + let args = vec![Arc::clone(&value), Arc::clone(&count)]; let datafusion_array_repeat = array_repeat_udf(); let array_repeat_expr: Arc = Arc::new(ScalarFunctionExpr::new( From 628092d6163614246dfe3d8bfb85c60812fb92b6 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Sun, 12 Jan 2025 09:50:15 +0530 Subject: [PATCH 03/10] fix arary_repeat function name --- native/core/src/execution/planner.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 96b1c443d9..8d6ab9b6d9 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -737,24 +737,22 @@ impl PhysicalPlanner { Ok(array_has_expr) } ExprStruct::ArrayRepeat(expr) => { - let value = + let src_expr = self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let count = + let count_expr = self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; - println!("value {:?}", value); - println!("count {:?}", count); - let return_type = value.data_type(&input_schema)?; - let args = vec![Arc::clone(&value), Arc::clone(&count)]; + let return_type = src_expr.data_type(&input_schema)?; + let args = vec![Arc::clone(&src_expr), Arc::clone(&count_expr)]; let datafusion_array_repeat = array_repeat_udf(); let array_repeat_expr: Arc = Arc::new(ScalarFunctionExpr::new( - "array_repat", + 
"array_repeat", datafusion_array_repeat, args, return_type, )); - let is_null_expr: Arc = Arc::new(IsNullExpr::new(count)); + let is_null_expr: Arc = Arc::new(IsNullExpr::new(count_expr)); let null_literal_expr: Arc = Arc::new(Literal::new(ScalarValue::Null)); From cb904df7c9aabfe996b5fb7ae2043d0cda2bf735 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Sat, 18 Jan 2025 21:22:09 +0530 Subject: [PATCH 04/10] updated the return type of expression, have to fix tests --- native/core/src/execution/planner.rs | 3 ++- .../apache/comet/CometExpressionSuite.scala | 21 ++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b0c1e117c6..4e411fb158 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -771,7 +771,8 @@ impl PhysicalPlanner { self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; let count_expr = self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; - let return_type = src_expr.data_type(&input_schema)?; + let element_type = src_expr.data_type(&input_schema)?; + let return_type = DataType::List(Arc::new(Field::new("item", element_type, true))); let args = vec![Arc::clone(&src_expr), Arc::clone(&count_expr)]; let datafusion_array_repeat = array_repeat_udf(); diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index a2c147fea5..4c3b4fe3e0 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2559,18 +2559,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled, 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1") -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 3) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(5, _2) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, null) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(2, null) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(null, 3) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(null, _3) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 0) from t1")) -// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, -1) from t1")) - checkSparkAnswerAndOperator(sql("SELECT array_repeat(8, 5) from t1")) - checkSparkAnswerAndOperator(sql("SELECT array_repeat(_8, 2) from t1")) - checkSparkAnswerAndOperator(sql("SELECT array_repeat(true, 3) from t1")) - checkSparkAnswerAndOperator(sql("SELECT array_repeat(1.5, 2) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_repeat(array(_2,_3), _2) from t1 where _2 is null")) + // checkSparkAnswerAndOperator( + // sql("SELECT array_repeat(_3, 0) from t1 where _3 is not null")) + // checkSparkAnswerAndOperator( + // sql("SELECT array_repeat(_3, 2) from t1 where _3 is not null")) + // checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 2) from t1 where _2 is null")) + // checkSparkAnswerAndOperator( + // sql("SELECT array_repeat(case when _2 = _3 THEN _8 ELSE null END, 2) from t1")) } } } From d31c880d9b71c30b5b9b956bc07a9a2f30857660 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Wed, 22 Jan 2025 21:55:37 +0530 Subject: [PATCH 05/10] updated array_repeat planner code --- 
native/core/src/execution/planner.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 4e411fb158..6f37639ca2 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -771,11 +771,25 @@ impl PhysicalPlanner { self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; let count_expr = self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; - let element_type = src_expr.data_type(&input_schema)?; - let return_type = DataType::List(Arc::new(Field::new("item", element_type, true))); + // Cast count_expr from Int32 to Int64 to support df count argument + let count_expr: Arc = match count_expr.data_type(&input_schema.clone())? { + DataType::Int32 => Arc::new(CastExpr::new( + count_expr, + DataType::Int64, + Some(CastOptions::default()), + )), + _ => count_expr, + }; + let args = vec![Arc::clone(&src_expr), Arc::clone(&count_expr)]; let datafusion_array_repeat = array_repeat_udf(); + let data_types: Vec = vec![ + src_expr.data_type(&Arc::clone(&input_schema))?, + count_expr.data_type(&Arc::clone(&input_schema))?, + ]; + let return_type = datafusion_array_repeat.return_type(&data_types)?; + let array_repeat_expr: Arc = Arc::new(ScalarFunctionExpr::new( "array_repeat", datafusion_array_repeat, From 3d6d2e30028a1a89e4eb3209e3caf2712c515902 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Wed, 22 Jan 2025 22:01:48 +0530 Subject: [PATCH 06/10] wip: test --- .../org/apache/comet/CometExpressionSuite.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 4c3b4fe3e0..56eccbe69e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2559,15 +2559,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled, 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1") + // checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, null) from t1")) checkSparkAnswerAndOperator( - sql("SELECT array_repeat(array(_2,_3), _2) from t1 where _2 is null")) - // checkSparkAnswerAndOperator( - // sql("SELECT array_repeat(_3, 0) from t1 where _3 is not null")) - // checkSparkAnswerAndOperator( - // sql("SELECT array_repeat(_3, 2) from t1 where _3 is not null")) - // checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 2) from t1 where _2 is null")) - // checkSparkAnswerAndOperator( - // sql("SELECT array_repeat(case when _2 = _3 THEN _8 ELSE null END, 2) from t1")) + sql("SELECT array_repeat(_4, 5) from t1 where _4 is not null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1")) } } } From 3b60be839cc682c19b4d585413e37491c3230ad8 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Wed, 22 Jan 2025 22:44:36 +0530 Subject: [PATCH 07/10] updated the test --- .../test/scala/org/apache/comet/CometExpressionSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index a186a1e4a5..f480068c4d 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2334,7 +2334,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - ignore("get_struct_field - select subset of struct") { + test("get_struct_field with DataFusion ParquetExec - select subset of struct") { withTempPath { dir => // create input file with Comet disabled withSQLConf(CometConf.COMET_ENABLED.key -> "false") { @@ -2698,9 +2698,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled, 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1") - // checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, null) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, _4) from t1")) checkSparkAnswerAndOperator( - sql("SELECT array_repeat(_4, 5) from t1 where _4 is not null")) + sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1")) From 76a4ff66b46faa1aeeae029a03b1a55344a5d81f Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Sun, 26 Jan 2025 18:27:24 +0530 Subject: [PATCH 08/10] added test case if the count variable is NULL --- spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 061084ae4e..f7fe25de3f 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2713,6 +2713,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator( sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _3) from t1 where _3 is null")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1")) checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1")) } From ab0bbb7dcbde635cfaa237d62c1e2d3e484622a5 Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Mon, 27 Jan 2025 21:38:55 +0530 Subject: [PATCH 09/10] fix for failing test --- native/core/src/execution/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c39992e34c..17e2b84336 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -784,7 +784,7 @@ impl PhysicalPlanner { self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; // Cast count_expr from Int32 to Int64 to support df count argument let count_expr: Arc = - match count_expr.data_type(&input_schema.clone())? { + match count_expr.data_type(Arc::::clone(&input_schema))? 
{ DataType::Int32 => Arc::new(CastExpr::new( count_expr, DataType::Int64, From ea67b7176c59df190eae423bef89f44fbb5b0b0d Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Tue, 28 Jan 2025 21:12:43 +0530 Subject: [PATCH 10/10] fixed error --- native/core/src/execution/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 17e2b84336..bef4d7c037 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -784,7 +784,7 @@ impl PhysicalPlanner { self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; // Cast count_expr from Int32 to Int64 to support df count argument let count_expr: Arc = - match count_expr.data_type(Arc::::clone(&input_schema))? { + match count_expr.data_type(&Arc::clone(&input_schema))? { DataType::Int32 => Arc::new(CastExpr::new( count_expr, DataType::Int64,
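
For readers following the series, the sketch below consolidates the ArrayRepeat match arm as it stands after the final patch, with the generic type parameters (e.g. Arc<dyn PhysicalExpr>, Vec<DataType>) restored where the patch text above dropped them. It assumes the imports already present in planner.rs (ScalarFunctionExpr, CastExpr, CaseExpr, IsNullExpr, Literal, ScalarValue, CastOptions, array_repeat_udf) and the surrounding create_expr match context; it is a reading aid, not a drop-in replacement for the diffs.

    ExprStruct::ArrayRepeat(expr) => {
        let src_expr =
            self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
        let count_expr =
            self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;

        // Spark serializes the count as Int32, while the DataFusion UDF expects Int64,
        // so the count is widened with a cast when needed.
        let count_expr: Arc<dyn PhysicalExpr> =
            match count_expr.data_type(&Arc::clone(&input_schema))? {
                DataType::Int32 => Arc::new(CastExpr::new(
                    count_expr,
                    DataType::Int64,
                    Some(CastOptions::default()),
                )),
                _ => count_expr,
            };

        let args = vec![Arc::clone(&src_expr), Arc::clone(&count_expr)];

        // Ask the UDF for its return type instead of hand-building the List type
        // (patch 05 replaced the earlier hard-coded DataType::List with this).
        let datafusion_array_repeat = array_repeat_udf();
        let data_types: Vec<DataType> = vec![
            src_expr.data_type(&Arc::clone(&input_schema))?,
            count_expr.data_type(&Arc::clone(&input_schema))?,
        ];
        let return_type = datafusion_array_repeat.return_type(&data_types)?;

        let array_repeat_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
            "array_repeat",
            datafusion_array_repeat,
            args,
            return_type,
        ));

        // To preserve Spark's NULL-count semantics, the UDF call is wrapped in
        // CASE WHEN count IS NULL THEN NULL ELSE array_repeat(src, count) END.
        let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(count_expr));
        let null_literal_expr: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Null));

        let case_expr = CaseExpr::try_new(
            None,
            vec![(is_null_expr, null_literal_expr)],
            Some(array_repeat_expr),
        )?;
        Ok(Arc::new(case_expr))
    }

On the Scala side nothing beyond the serde mapping introduced in patch 01 is required, since the expression travels through the proto as a plain BinaryExpr (array_repeat = 60 in expr.proto).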