diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c40e2e73f9..da452c2f15 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -70,6 +70,7 @@ use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctio use crate::execution::shuffle::CompressionCodec; use crate::execution::spark_plan::SparkPlan; +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_comet_proto::{ spark_expression::{ self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, @@ -1183,17 +1184,42 @@ impl PhysicalPlanner { false, )?); - Ok(( - scans, - Arc::new(SparkPlan::new( - spark_plan.plan_id, - join, - vec![ - Arc::clone(&join_params.left), - Arc::clone(&join_params.right), - ], - )), - )) + if join.filter.is_some() { + // SMJ with join filter produces lots of tiny batches + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new( + Arc::::clone(&join), + self.session_ctx + .state() + .config_options() + .execution + .batch_size, + )); + Ok(( + scans, + Arc::new(SparkPlan::new_with_additional( + spark_plan.plan_id, + coalesce_batches, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + vec![join], + )), + )) + } else { + Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + join, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + )), + )) + } } OpStruct::HashJoin(join) => { let (join_params, scans) = self.parse_join_parameters( diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 8bff6b5fbd..52a0d5e180 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -567,7 +567,7 @@ class CometSparkSessionExtensions case op: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) && - op.children.forall(isCometNative(_)) => + op.children.forall(isCometNative) => val newOp = transform1(op) newOp match { case Some(nativeOp) => diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index dc081b196b..7ed3725bed 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2859,11 +2859,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case RightOuter => JoinType.RightOuter case FullOuter => JoinType.FullOuter case LeftSemi => JoinType.LeftSemi - // TODO: DF SMJ with join condition fails TPCH q21 - case LeftAnti if condition.isEmpty => JoinType.LeftAnti - case LeftAnti => - withInfo(join, "LeftAnti SMJ join with condition is not supported") - return None + case LeftAnti => JoinType.LeftAnti case _ => // Spark doesn't support other join types withInfo(op, s"Unsupported join type ${join.joinType}") diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index ad1aef4a8f..d756da1515 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -216,7 +216,7 @@ class CometJoinSuite extends CometTestBase { v.toDouble, v.toString, v % 2 == 0, - v.toString().getBytes, + v.toString.getBytes, Decimal(v)) withParquetTable((0 until 10).map(i => manyTypes(i, i % 5)), "tbl_a") { @@ -294,6 +294,7 @@ class CometJoinSuite extends CometTestBase { test("SortMergeJoin without join filter") { withSQLConf( + CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true", SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { @@ -338,9 +339,9 @@ class CometJoinSuite extends CometTestBase { } } - // https://github.com/apache/datafusion-comet/issues/398 - ignore("SortMergeJoin with join filter") { + test("SortMergeJoin with join filter") { withSQLConf( + CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true", CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { @@ -391,9 +392,6 @@ class CometJoinSuite extends CometTestBase { "AND tbl_a._2 >= tbl_b._1") checkSparkAnswerAndOperator(df9) - // TODO: Enable these tests after fixing the issue: - // https://github.com/apache/datafusion-comet/issues/861 - /* val df10 = sql( "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b ON tbl_a._2 = tbl_b._1 " + "AND tbl_a._2 >= tbl_b._1") @@ -403,7 +401,6 @@ class CometJoinSuite extends CometTestBase { "SELECT * FROM tbl_b LEFT ANTI JOIN tbl_a ON tbl_a._2 = tbl_b._1 " + "AND tbl_a._2 >= tbl_b._1") checkSparkAnswerAndOperator(df11) - */ } } }