diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 12ba41145c20..f42ed15deb45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2923,7 +2923,22 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     val extraAggExprs = new LinkedHashMap[Expression, NamedExpression]
     val transformed = exprs.map { e =>
       if (!e.resolved) {
-        e
+        val aggregatedCondition =
+          Aggregate(
+            agg.groupingExpressions,
+            Alias(e, "havingCondition")() :: Nil,
+            agg.child)
+        val resolvedOperator = executeSameContext(aggregatedCondition)
+        def resolvedAggregateFilter =
+          resolvedOperator
+            .asInstanceOf[Aggregate]
+            .aggregateExpressions.head
+
+        if (resolvedOperator.resolved) {
+          buildAggExprList(resolvedAggregateFilter, agg, extraAggExprs)
+        } else {
+          e
+        }
       } else {
         buildAggExprList(e, agg, extraAggExprs)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dd4a6535619f..037585be4dc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -5059,6 +5059,28 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-53094: Fix cube-related data quality problem") {
+    withTable("table1") {
+      withSQLConf() {
+        sql(
+          """CREATE TABLE table1(product string, amount bigint,
+            |region string) using csv""".stripMargin)
+
+        sql("INSERT INTO table1 " + "VALUES('a', 100, 'east')")
+        sql("INSERT INTO table1 " + "VALUES('b', 200, 'east')")
+        sql("INSERT INTO table1 " + "VALUES('a', 150, 'west')")
+        sql("INSERT INTO table1 " + "VALUES('b', 250, 'west')")
+        sql("INSERT INTO table1 " + "VALUES('a', 120, 'east')")
+
+        checkAnswer(
+          sql("select product, region, sum(amount) as s " +
+            "from table1 group by product, region with cube having count(product) > 2 " +
+            "order by s desc"),
+          Seq(Row(null, null, 820), Row(null, "east", 420), Row("a", null, 370)))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])