diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
index 5e6f1b5a8840..8ce2452cc141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, RepartitionOperation, Statistics}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOGICAL_QUERY_STAGE, REPARTITION_OPERATION, TreePattern}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 
 /**
  * The LogicalPlan wrapper for a [[QueryStageExec]], or a snippet of physical plan containing
@@ -53,8 +54,14 @@ case class LogicalQueryStage(
   override def computeStats(): Statistics = {
     // TODO this is not accurate when there is other physical nodes above QueryStageExec.
     val physicalStats = physicalPlan.collectFirst {
-      case s: QueryStageExec => s
-    }.flatMap(_.computeStats())
+      case a: BaseAggregateExec if a.groupingExpressions.isEmpty =>
+        a.collectFirst {
+          case s: QueryStageExec => s.computeStats()
+        }.flatten.map { stat =>
+          if (stat.rowCount.contains(0)) stat.copy(rowCount = Some(1)) else stat
+        }
+      case s: QueryStageExec => s.computeStats()
+    }.flatten
     if (physicalStats.isDefined) {
       logDebug(s"Physical stats available as ${physicalStats.get} for plan: $physicalPlan")
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 58936f5d8dc8..79f2b6b46577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2841,6 +2841,14 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-44040: Fix compute stats when AggregateExec nodes above QueryStageExec") {
+    val emptyDf = spark.range(1).where("false")
+    val aggDf1 = emptyDf.agg(sum("id").as("id")).withColumn("name", lit("df1"))
+    val aggDf2 = emptyDf.agg(sum("id").as("id")).withColumn("name", lit("df2"))
+    val unionDF = aggDf1.union(aggDf2)
+    checkAnswer(unionDF.select("id").distinct, Seq(Row(null)))
+  }
 }
 
 /**