diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f031f0816db1..1a70588c2119 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -580,7 +580,7 @@ class Analyzer(override val catalogManager: CatalogManager) aggregations: Seq[NamedExpression], groupByAliases: Seq[Alias], groupingAttrs: Seq[Expression], - gid: Attribute): Seq[NamedExpression] = aggregations.map { + gid: Attribute): Seq[NamedExpression] = aggregations.map { agg => // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. val aggsBuffer = ArrayBuffer[Expression]() @@ -588,7 +588,7 @@ class Analyzer(override val catalogManager: CatalogManager) def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - replaceGroupingFunc(_, groupByExprs, gid).transformDown { + replaceGroupingFunc(agg, groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ed3b47947a5d..032ddbbcebf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3405,6 +3405,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("SPARK-36339: References to grouping attributes should be replaced") { + withTempView("t") { + Seq("a", "a", "b").toDF("x").createOrReplaceTempView("t") + checkAnswer( + sql( + """ + |select count(x) c, x from t + |group by x grouping sets(x) + """.stripMargin), + Seq(Row(2, "a"), Row(1, "b"))) + } + } + test("SPARK-31166: UNION map and other maps should not fail") { checkAnswer( sql("(SELECT map()) UNION ALL (SELECT map(1, 2))"),