Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -580,15 +580,15 @@ class Analyzer(override val catalogManager: CatalogManager)
aggregations: Seq[NamedExpression],
groupByAliases: Seq[Alias],
groupingAttrs: Seq[Expression],
gid: Attribute): Seq[NamedExpression] = aggregations.map {
gid: Attribute): Seq[NamedExpression] = aggregations.map { agg =>
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
val aggsBuffer = ArrayBuffer[Expression]()
// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
replaceGroupingFunc(_, groupByExprs, gid).transformDown {
replaceGroupingFunc(agg, groupByExprs, gid).transformDown {
Copy link
Contributor

@cfmcgrady cfmcgrady Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between underscore?

Copy link
Contributor Author

@gaoyajun02 gaoyajun02 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the underscore, the aggsBuffer is outside the scope of the map function at runtime and it will save the results of all elements.
using normal parameters, the aggsBuffer will only be recreated each time inside the map function loop.

I suspect that Scala syntactic sugar in the conversion of the code made changes to cause, I also debugged this code many times before I found this difference, here is a simplified code to test separately.

    def testMap(seq: Seq[Int]): Seq[Int] = {
      seq.map {
        val buf = ArrayBuffer[Int]()
        _ match {
          case e: Int if e < 1 =>
            val r = e + 1
            println(s"add to buf: $r")
            buf += r
            r
          case e: Int if buf.contains(e) =>
            println("already in buf")
            0
          case e =>
            println("not in buf")
            e
        }
      }
    }

    testMap(Seq(0, 1))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the detailed explanation.

// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
Expand Down
13 changes: 13 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3405,6 +3405,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}

test("SPARK-36339: References to grouping attributes should be replaced") {
withTempView("t") {
Seq("a", "a", "b").toDF("x").createOrReplaceTempView("t")
checkAnswer(
sql(
"""
|select count(x) c, x from t
|group by x grouping sets(x)
""".stripMargin),
Seq(Row(2, "a"), Row(1, "b")))
}
}

test("SPARK-31166: UNION map<null, null> and other maps should not fail") {
checkAnswer(
sql("(SELECT map()) UNION ALL (SELECT map(1, 2))"),
Expand Down