Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,23 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}

/**
 * Tells whether any of the given expressions contains, anywhere in its tree, an
 * [[UnresolvedAttribute]] whose name the resolver matches against
 * `VirtualColumn.groupingIdName`.
 *
 * @param expr the expressions to inspect
 * @return true if at least one expression refers to the grouping id virtual column by name
 */
private def hasGroupingId(expr: Seq[Expression]): Boolean = {
  // Only unresolved attributes are considered; every other node is ignored.
  def refersToGroupingId(node: Expression): Boolean = node match {
    case attr: UnresolvedAttribute => resolver(attr.name, VirtualColumn.groupingIdName)
    case _ => false
  }

  expr.exists(e => e.find(refersToGroupingId).isDefined)
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
case x: GroupingSets =>
case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
failAnalysis(
s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()

// Expand works by setting grouping expressions to null as determined by the bitmasks. To
Expand Down
34 changes: 34 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}

// Exercises GROUPING SETS queries in which the aggregated column (earnings) is also one of
// the GROUP BY columns, first without and then with grouping_id() in the select list.
test("grouping sets when aggregate functions containing groupBy columns") {
  // Sum of earnings per grouping set, ordered by course and then by the sum.
  val sumQuery =
    "select course, sum(earnings) as sum from courseSales group by course, earnings " +
      "grouping sets((), (course), (course, earnings)) " +
      "order by course, sum"
  val expectedSums = Seq(
    Row(null, 113000.0),
    Row("Java", 20000.0),
    Row("Java", 30000.0),
    Row("Java", 50000.0),
    Row("dotNET", 5000.0),
    Row("dotNET", 10000.0),
    Row("dotNET", 48000.0),
    Row("dotNET", 63000.0))
  checkAnswer(sql(sumQuery), expectedSums)

  // Same grouping sets, with grouping_id(course, earnings) added as a third column.
  val sumWithGroupingIdQuery =
    "select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " +
      "group by course, earnings grouping sets((), (course), (course, earnings)) " +
      "order by course, sum"
  val expectedSumsWithGroupingId = Seq(
    Row(null, 113000.0, 3),
    Row("Java", 20000.0, 0),
    Row("Java", 30000.0, 0),
    Row("Java", 50000.0, 1),
    Row("dotNET", 5000.0, 0),
    Row("dotNET", 10000.0, 0),
    Row("dotNET", 48000.0, 0),
    Row("dotNET", 63000.0, 1))
  checkAnswer(sql(sumWithGroupingIdQuery), expectedSumsWithGroupingId)
}

test("cube") {
checkAnswer(
sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
Expand Down Expand Up @@ -2103,6 +2133,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql("select course, year, grouping_id(course, year) from courseSales group by course, year")
}
assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup")
error = intercept[AnalysisException] {
sql("select course, year, grouping__id from courseSales group by cube(course, year)")
}
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}

test("SPARK-13056: Null in map value causes NPE") {
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -123,60 +123,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
assertBroadcastNestedLoopJoin(spark_10484_4)
}

// Query test for a single-column ROLLUP over `src`: selects count(*), key % 5 and the
// GROUPING__ID virtual column, grouped by key%5.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for Rollup #1",
"""
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP
""".stripMargin)

// Query test for a two-column ROLLUP over `src` (key%5, key-5) that also selects
// GROUPING__ID; the result is ordered by cnt, k1, k2, k3 and limited to 10 rows.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for Rollup #2",
"""
SELECT
count(*) AS cnt,
key % 5 as k1,
key-5 as k2,
GROUPING__ID as k3
FROM src group by key%5, key-5
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
""".stripMargin)

// Same two-column ROLLUP as test #2, but reading from a subquery over `src` that projects
// key, key%2 and key - 5; ordered by cnt, k1, k2, k3 and limited to 10 rows.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for Rollup #3",
"""
SELECT
count(*) AS cnt,
key % 5 as k1,
key-5 as k2,
GROUPING__ID as k3
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
""".stripMargin)

// Query test for a single-column CUBE over `src`: selects count(*), key % 5 and the
// GROUPING__ID virtual column, grouped by key%5.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for CUBE #1",
"""
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE
""".stripMargin)

// Query test for a two-column CUBE (key%5, key-5) over a subquery on `src` that also
// selects GROUPING__ID; ordered by cnt, k1, k2, k3 and limited to 10 rows.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for CUBE #2",
"""
SELECT
count(*) AS cnt,
key % 5 as k1,
key-5 as k2,
GROUPING__ID as k3
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
""".stripMargin)

// Query test for explicit GROUPING SETS (key%5, key-5) over a subquery on `src` that also
// selects GROUPING__ID; ordered by cnt, k1, k2, k3 and limited to 10 rows.
// The query literal contains no `|` margin characters, so `.stripMargin` leaves it unchanged.
createQueryTest("SPARK-8976 Wrong Result for GroupingSet",
"""
SELECT
count(*) AS cnt,
key % 5 as k1,
key-5 as k2,
GROUPING__ID as k3
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
""".stripMargin)

createQueryTest("insert table with generator with column name",
"""
| CREATE TABLE gen_tmp (key Int);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use grouping_id() instead

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hive does not have a grouping_id() function, so the golden answer for this test cannot be generated:

SPARK-8976 Wrong Result for GroupingSet *** FAILED *** (316 milliseconds)
[info]   Failed to generate golden answer for query:
[info]   Error: FAILED: SemanticException [Error 10011]: Line 5:8 Invalid function 'grouping_id'
[info]   org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10011]: Line 5:8 Invalid function 'grouping_id'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have the results, just copy them as golden files (or copy them into test cases).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,116 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
}

// Single-column WITH ROLLUP over `src`: checks count(*), key % 5 and grouping_id().
test("SPARK-8976 Wrong Result for Rollup #1") {
  // The row with a null key % 5 carries grouping_id() = 1; all other rows carry 0.
  val expected = Seq(
    Row(113, 3, 0),
    Row(91, 0, 0),
    Row(500, null, 1),
    Row(84, 1, 0),
    Row(105, 2, 0),
    Row(107, 4, 0))
  checkAnswer(
    sql("SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"),
    expected)
}

// Two-column WITH ROLLUP over `src`: checks the first 10 rows ordered by cnt, k1, k2, k3.
test("SPARK-8976 Wrong Result for Rollup #2") {
  val rollupQuery =
    """
      |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
      |FROM src GROUP BY key%5, key-5
      |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
      |""".stripMargin
  // Columns: cnt, k1, k2, k3.
  val expected = Seq(
    Row(1, 0, 5, 0),
    Row(1, 0, 15, 0),
    Row(1, 0, 25, 0),
    Row(1, 0, 60, 0),
    Row(1, 0, 75, 0),
    Row(1, 0, 80, 0),
    Row(1, 0, 100, 0),
    Row(1, 0, 140, 0),
    Row(1, 0, 145, 0),
    Row(1, 0, 150, 0))
  checkAnswer(sql(rollupQuery), expected)
}

// Two-column WITH ROLLUP reading from a subquery over `src`: checks the first 10 rows
// ordered by cnt, k1, k2, k3.
test("SPARK-8976 Wrong Result for Rollup #3") {
  val rollupQuery =
    """
      |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
      |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
      |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
      |""".stripMargin
  // Columns: cnt, k1, k2, k3.
  val expected = Seq(
    Row(1, 0, 5, 0),
    Row(1, 0, 15, 0),
    Row(1, 0, 25, 0),
    Row(1, 0, 60, 0),
    Row(1, 0, 75, 0),
    Row(1, 0, 80, 0),
    Row(1, 0, 100, 0),
    Row(1, 0, 140, 0),
    Row(1, 0, 145, 0),
    Row(1, 0, 150, 0))
  checkAnswer(sql(rollupQuery), expected)
}

// Single-column WITH CUBE over `src`: checks count(*), key % 5 and grouping_id().
test("SPARK-8976 Wrong Result for CUBE #1") {
  // The row with a null key % 5 carries grouping_id() = 1; all other rows carry 0.
  val expected = Seq(
    Row(113, 3, 0),
    Row(91, 0, 0),
    Row(500, null, 1),
    Row(84, 1, 0),
    Row(105, 2, 0),
    Row(107, 4, 0))
  checkAnswer(
    sql("SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"),
    expected)
}

// Two-column WITH CUBE reading from a subquery over `src`: checks the first 10 rows
// ordered by cnt, k1, k2, k3.
test("SPARK-8976 Wrong Result for CUBE #2") {
  val cubeQuery =
    """
      |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
      |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
      |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
      |""".stripMargin
  // Columns: cnt, k1, k2, k3.
  val expected = Seq(
    Row(1, null, -3, 2),
    Row(1, null, -1, 2),
    Row(1, null, 3, 2),
    Row(1, null, 4, 2),
    Row(1, null, 5, 2),
    Row(1, null, 6, 2),
    Row(1, null, 12, 2),
    Row(1, null, 14, 2),
    Row(1, null, 15, 2),
    Row(1, null, 22, 2))
  checkAnswer(sql(cubeQuery), expected)
}

// Explicit GROUPING SETS (key%5, key-5) reading from a subquery over `src`: checks the
// first 10 rows ordered by cnt, k1, k2, k3.
test("SPARK-8976 Wrong Result for GroupingSet") {
  val groupingSetsQuery =
    """
      |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
      |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
      |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
      |""".stripMargin
  // Columns: cnt, k1, k2, k3.
  val expected = Seq(
    Row(1, null, -3, 2),
    Row(1, null, -1, 2),
    Row(1, null, 3, 2),
    Row(1, null, 4, 2),
    Row(1, null, 5, 2),
    Row(1, null, 6, 2),
    Row(1, null, 12, 2),
    Row(1, null, 14, 2),
    Row(1, null, 15, 2),
    Row(1, null, 22, 2))
  checkAnswer(sql(groupingSetsQuery), expected)
}

test("SPARK-10562: partition by column with mixed case name") {
withTable("tbl10562") {
val df = Seq(2012 -> "a").toDF("Year", "val")
Expand Down