diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 6ad0793fb642..d02f12d67e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -28,14 +28,17 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_ /** * The base class of two rules in the normal and AQE Optimizer. It simplifies query plans with * empty or non-empty relations: - * 1. Binary-node Logical Plans + * 1. Higher-node Logical Plans + * - Union with empty children: they are pruned; an all-empty Union becomes an empty relation. + * 2. Binary-node Logical Plans * - Join with one or two empty children (including Intersect/Except). * - Left semi Join * Right side is non-empty and condition is empty. Eliminate join to its left side. * - Left anti join * Right side is non-empty and condition is empty. Eliminate join to an empty * [[LocalRelation]]. - * 2. Unary-node Logical Plans + * 3. Unary-node Logical Plans + * - Project/Filter/Sample with all empty children. * - Limit/Repartition with all empty children. * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. 
@@ -59,6 +62,31 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Union if p.children.exists(isEmpty) => + val newChildren = p.children.filterNot(isEmpty) + if (newChildren.isEmpty) { + empty(p) + } else { + val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head + val outputs = newPlan.output.zip(p.output) + // The pruned plan may expose different output attributes than the original Union, so + // alias each attribute whose exprId changed back to the original name and exprId. + if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) { + newPlan + } else { + val newOutput = outputs.map { case (newAttr, oldAttr) => + if (newAttr.exprId == oldAttr.exprId) { + newAttr + } else { + val newExplicitMetadata = + if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None + Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata) + } + } + Project(newOutput, newPlan) + } + } + // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than // just processing the input data. 
@@ -98,7 +126,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup p } + // The only plan expected to match here is an empty LogicalQueryStage. + case p: LeafNode if !p.isInstanceOf[LocalRelation] && isEmpty(p) => empty(p) + case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) => p match { + case _: Project => empty(p) + case _: Filter => empty(p) + case _: Sample => empty(p) case _: Sort => empty(p) case _: GlobalLimit if !p.isStreaming => empty(p) case _: LocalLimit if !p.isStreaming => empty(p) @@ -128,53 +162,11 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup } /** - * This rule runs in the normal optimizer and optimizes more cases - * compared to [[PropagateEmptyRelationBase]]: - * 1. Higher-node Logical Plans - * - Union with all empty children. - * 2. Unary-node Logical Plans - * - Project/Filter/Sample with all empty children. - * - * The reason why we don't apply this rule at AQE optimizer side is: the benefit is not big enough - * and it may introduce extra exchanges. 
+ * This rule runs in the normal optimizer and applies `commonApplyFunc` as is. */ object PropagateEmptyRelation extends PropagateEmptyRelationBase { - private def applyFunc: PartialFunction[LogicalPlan, LogicalPlan] = { - case p: Union if p.children.exists(isEmpty) => - val newChildren = p.children.filterNot(isEmpty) - if (newChildren.isEmpty) { - empty(p) - } else { - val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head - val outputs = newPlan.output.zip(p.output) - // the original Union may produce different output attributes than the new one so we alias - // them if needed - if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) { - newPlan - } else { - val outputAliases = outputs.map { case (newAttr, oldAttr) => - val newExplicitMetadata = - if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None - Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata) - } - Project(outputAliases, newPlan) - } - } - - case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) && canPropagate(p) => - empty(p) - } - - // extract the pattern avoid conflict with commonApplyFunc - private def canPropagate(plan: LogicalPlan): Boolean = plan match { - case _: Project => true - case _: Filter => true - case _: Sample => true - case _ => false - } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) { - applyFunc.orElse(commonApplyFunc) + commonApplyFunc } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index ea2fb1c3130a..bab77515f79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ 
-21,6 +21,7 @@ import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys /** @@ -32,14 +33,28 @@ import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys */ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { override protected def isEmpty(plan: LogicalPlan): Boolean = - super.isEmpty(plan) || getRowCount(plan).contains(0) + super.isEmpty(plan) || getEstimatedRowCount(plan).contains(0) override protected def nonEmpty(plan: LogicalPlan): Boolean = - super.nonEmpty(plan) || getRowCount(plan).exists(_ > 0) + super.nonEmpty(plan) || getEstimatedRowCount(plan).exists(_ > 0) - private def getRowCount(plan: LogicalPlan): Option[BigInt] = plan match { + // Meaning of the returned value: + // - 0: the plan is guaranteed to produce no rows + // - a positive value: an estimated row count, which may be an over-estimate + // - None: the plan has not materialized or its row count cannot be estimated + private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match { case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized => stage.getRuntimeStatistics.rowCount + + case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty && + agg.child.isInstanceOf[QueryStageExec] => + val stage = agg.child.asInstanceOf[QueryStageExec] + if (stage.isMaterialized) { + stage.getRuntimeStatistics.rowCount + } else { + None + } + case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 51d476703a76..a29989cc06c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1406,6 +1406,56 @@ class AdaptiveQueryExecSuite } } + test("SPARK-35442: Support propagate empty relation through aggregate") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( + "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key") + assert(!plan1.isInstanceOf[LocalTableScanExec]) + assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) + + val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( + "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1") + assert(!plan2.isInstanceOf[LocalTableScanExec]) + assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) + + val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult( + "SELECT count(*) FROM testData WHERE value = 'no_match'") + assert(!plan3.isInstanceOf[LocalTableScanExec]) + assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec]) + } + } + + test("SPARK-35442: Support propagate empty relation through union") { + def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = { + assert( + collect(plan) { + case u: UnionExec => u + }.size == numUnion) + } + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( + """ + |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key + |UNION ALL + |SELECT key, 1 FROM testData + |""".stripMargin) + checkNumUnion(plan1, 1) + checkNumUnion(adaptivePlan1, 0) + assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) + + val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( + """ + 
|SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key + |UNION ALL + |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match' + |""".stripMargin) + checkNumUnion(plan2, 1) + checkNumUnion(adaptivePlan2, 0) + assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) + } + } + test("SPARK-32753: Only copy tags to node with no tags") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { withTempView("v1") { @@ -1794,7 +1844,8 @@ class AdaptiveQueryExecSuite test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") { withTable("t") { withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", - SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { spark.sql("CREATE TABLE t (c1 int) USING PARQUET") val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1") assert( @@ -2261,7 +2312,8 @@ class AdaptiveQueryExecSuite test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584") { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be // estimated at ~120m bytes which is greater than the broadcast join threshold. val joinKeyOne = "00112233445566778899"