From ae1186f4be87b2136c3e55bf4ae3d41c58b03142 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 21 Mar 2021 09:32:47 +0100 Subject: [PATCH 01/15] [SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions --- .../sql/catalyst/expressions/grouping.scala | 19 ++++++++++++++++++- .../sql/catalyst/optimizer/Optimizer.scala | 15 +++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 17 +++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index f843c1a2d359..89c46dc202ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -155,3 +155,20 @@ object GroupingID { if (SQLConf.get.integerGroupingIdEnabled) IntegerType else LongType } } + +/** + * Wrapper expression to avoid further optizations of child + */ +case class GroupingExpression(child: Expression) extends UnaryExpression { + override def eval(input: InternalRow): Any = { + child.eval(input) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } + + override def dataType: DataType = { + child.dataType + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3e3550d5da89..6f0508c10d19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -870,8 +870,19 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - agg.copy(aggregateExpressions = buildCleanedProjectList( - p.projectList, agg.aggregateExpressions)) + val complexGroupingExpressions = + ExpressionSet(agg.groupingExpressions.filter(_.children.nonEmpty)) + + def wrapGroupingExpression(e: Expression): Expression = e match { + case _: AggregateExpression => e + case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e) + case _ => e.mapChildren(wrapGroupingExpression) + } + + val wrappedAggregateExpressions = + agg.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression]) + agg.copy(aggregateExpressions = + buildCleanedProjectList(p.projectList, wrappedAggregateExpressions)) } case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 00cbd73533ab..b88ae4c3b6fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4116,6 +4116,23 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with 
AdaptiveSpark } } } + + test("SPARK-34581: Don't optimize out grouping expressions from aggregate expressions") { + withTempView("t") { + Seq[Integer](null, 1, 2, 3, null).toDF("id").createOrReplaceTempView("t") + + val df = spark.sql( + """ + |SELECT not(id), c + |FROM ( + | SELECT t.id IS NULL AS id, count(*) AS c + | FROM t + | GROUP BY t.id IS NULL + |) t + |""".stripMargin) + checkAnswer(df, Row(true, 3) :: Row(false, 2) :: Nil) + } + } } case class Foo(bar: Option[String]) From 5ab9f75cfcd403d38fde6d0e39a319fc550c10f1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 21 Mar 2021 14:28:41 +0100 Subject: [PATCH 02/15] comment fix --- .../org/apache/spark/sql/catalyst/expressions/grouping.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 89c46dc202ca..7d893e72423c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -157,7 +157,8 @@ object GroupingID { } /** - * Wrapper expression to avoid further optizations of child + * Wrapper expression to avoid further optimizations between the parent and child of this + * expression. */ case class GroupingExpression(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = { From 2293fd40449b9b540de9aa6614b6ca10698b7433 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 22 Mar 2021 08:38:17 +0100 Subject: [PATCH 03/15] move logic to the beginning of optimization, simplify test --- .../sql/catalyst/optimizer/Optimizer.scala | 16 ++--------- .../catalyst/optimizer/finishAnalysis.scala | 28 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6f0508c10d19..391f8e700cb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -147,6 +147,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView, ReplaceExpressions, RewriteNonCorrelatedExists, + WrapGroupingExpressions, ComputeCurrentTime, GetCurrentDatabaseAndCatalog(catalogManager)) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -870,19 +871,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - val complexGroupingExpressions = - ExpressionSet(agg.groupingExpressions.filter(_.children.nonEmpty)) - - def wrapGroupingExpression(e: Expression): Expression = e match { - case _: AggregateExpression => e - case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e) - case _ => e.mapChildren(wrapGroupingExpression) - } - - val wrappedAggregateExpressions = - agg.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression]) - agg.copy(aggregateExpressions = - buildCleanedProjectList(p.projectList, wrappedAggregateExpressions)) + agg.copy(aggregateExpressions = buildCleanedProjectList( + p.projectList, agg.aggregateExpressions)) } case 
Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 897d74a0f947..6fa9f9062a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -66,6 +66,34 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { } } +/** + * Wrap complex grouping expression in aggregate expressions without aggregate function into + * `GroupingExpression` nodes so as to avoid further optimizations between the expression and its + * parent. + * + * This is required as further optimizations could change the grouping expression and so make the + * aggregate expression invalid. + */ +object WrapGroupingExpressions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case a: Aggregate => + val complexGroupingExpressions = + ExpressionSet(a.groupingExpressions.filter(_.children.nonEmpty)) + + def wrapGroupingExpression(e: Expression): Expression = e match { + case _: GroupingExpression => e + case _: AggregateExpression => e + case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e) + case _ => e.mapChildren(wrapGroupingExpression) + } + + a.copy(aggregateExpressions = + a.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression])) + } + } +} + /** * Computes the current date and time to make sure we return the same result in a single query. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b88ae4c3b6fb..5ca55df3f77a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4123,12 +4123,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val df = spark.sql( """ - |SELECT not(id), c - |FROM ( - | SELECT t.id IS NULL AS id, count(*) AS c - | FROM t - | GROUP BY t.id IS NULL - |) t + |SELECT not(t.id IS NULL), count(*) AS c + |FROM t + |GROUP BY t.id IS NULL |""".stripMargin) checkAnswer(df, Row(true, 3) :: Row(false, 2) :: Nil) } From 3de19cae711790a4c520221da1965277011f6f65 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 22 Mar 2021 09:11:43 +0100 Subject: [PATCH 04/15] regenerate approved plans --- .../approved-plans-v1_4/q62.sf100/explain.txt | 2 +- .../tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt | 2 +- .../approved-plans-v1_4/q99.sf100/explain.txt | 2 +- .../tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt index 41462a738912..a876fb70a86b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#13, web_name#10, sum# Keys [3]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#13, web_name#10] Functions [5]: 
[sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [substr(w_warehouse_name#16, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [groupingexpression(substr(w_warehouse_name#16, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt index b5aa53f2de12..802c9a5f3c3a 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, web_name#13, sum#2 Keys [3]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, 
web_name#13] Functions [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [substr(w_warehouse_name#7, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [groupingexpression(substr(w_warehouse_name#7, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt index 9d653586e519..76553805e743 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#10, cc_name#13, sum#2 Keys [3]: 
[substr(w_warehouse_name#16, 1, 20)#18, sm_type#10, cc_name#13] Functions [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [substr(w_warehouse_name#16, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [groupingexpression(substr(w_warehouse_name#16, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt index c05e9595220c..ff180769e859 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, cc_name#13, 
sum#24 Keys [3]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, cc_name#13] Functions [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [substr(w_warehouse_name#7, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [groupingexpression(substr(w_warehouse_name#7, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] From 6e05f14c1a21e2894c1a4a85546a790189a43f5a Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 23 Mar 2021 14:38:38 +0100 Subject: [PATCH 05/15] define GroupingExpression as TaggingExpression --- .../expressions/constraintExpressions.scala | 6 ++++++ .../sql/catalyst/expressions/grouping.scala | 20 +------------------ 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 5bfae7b77e09..cdaed2bf404e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -39,3 +39,9 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { } case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression + +/** + * Wrapper expression to avoid further optimizations between the parent and child of this + * expression. + */ +case class GroupingExpression(child: Expression) extends TaggingExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 7d893e72423c..f843c1a2d359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -155,21 +155,3 @@ object GroupingID { if (SQLConf.get.integerGroupingIdEnabled) IntegerType else LongType } } - -/** - * Wrapper expression to avoid further optimizations between the parent and child of this - * expression. - */ -case class GroupingExpression(child: Expression) extends UnaryExpression { - override def eval(input: InternalRow): Any = { - child.eval(input) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - child.genCode(ctx) - } - - override def dataType: DataType = { - child.dataType - } -} From 09f1a85d8312e56c886b31f26ae87d245288d2b9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 24 Mar 2021 10:59:17 +0100 Subject: [PATCH 06/15] move test to SQLQueryTestSuite --- .../test/resources/sql-tests/inputs/group-by.sql | 5 +++++ .../resources/sql-tests/results/group-by.sql.out | 13 ++++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 14 -------------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 6ee101473975..8f03e08f3cbc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -179,3 +179,8 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max( -- Aggregate with multiple distinct decimal columns SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col); + +-- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function +SELECT not(a IS NULL), count(*) AS c +FROM testData +GROUP BY a IS NULL diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 978087485871..cc52151fb395 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ 
-1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 62 +-- Number of queries: 63 -- !query @@ -642,3 +642,14 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 struct -- !query output 1.0000 1 + + +-- !query +SELECT not(a IS NULL), count(*) AS c +FROM testData +GROUP BY a IS NULL +-- !query schema +struct<(NOT (a IS NULL)):boolean,c:bigint> +-- !query output +false 2 +true 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 283f19785141..7e7853e1799d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4140,20 +4140,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } - - test("SPARK-34581: Don't optimize out grouping expressions from aggregate expressions") { - withTempView("t") { - Seq[Integer](null, 1, 2, 3, null).toDF("id").createOrReplaceTempView("t") - - val df = spark.sql( - """ - |SELECT not(t.id IS NULL), count(*) AS c - |FROM t - |GROUP BY t.id IS NULL - |""".stripMargin) - checkAnswer(df, Row(true, 3) :: Row(false, 2) :: Nil) - } - } } case class Foo(bar: Option[String]) From f46b89d6d81be4a7867fd3bc9edeccc1f7e9aea5 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 24 Mar 2021 12:14:16 +0100 Subject: [PATCH 07/15] add more explanation --- .../sql/catalyst/optimizer/finishAnalysis.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 6fa9f9062a9a..2166a34421d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -67,12 +67,19 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { } /** - * Wrap complex grouping expression in aggregate expressions without aggregate function into - * `GroupingExpression` nodes so as to avoid further optimizations between the expression and its - * parent. - * - * This is required as further optimizations could change the grouping expression and so make the + * Wrap some of the grouping expressions in aggregate expressions without aggregate functions into + * `GroupingExpression` nodes so as to avoid optimizations between the expression and its parent. + * This is required as optimizations could change these grouping expressions and so make the * aggregate expression invalid. + * We only need to wrap complex expressions (expressions with children so they are more than just + * an attribute or a literal) which can be subject of optimizations. + * + * For example, in the following query Spark shouldn't optimize the aggregate expression + * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`: + * SELECT not(c IS NULL) + * FROM t + * GROUP BY c IS NULL + * This rule changes the aggregate expression to `Not(GroupingExpression(IsNull(c)))`. 
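+ *
+ * The rewrite has the following shape (the plan fragments are illustrative only, they are not
+ * produced verbatim by this rule):
+ * {{{
+ *   // before: Aggregate([IsNull(c)], [Alias(Not(IsNull(c)), "x")], child)
+ *   // after:  Aggregate([IsNull(c)], [Alias(Not(GroupingExpression(IsNull(c))), "x")], child)
+ * }}}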
*/ object WrapGroupingExpressions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { From 977c0bff247987eec7efd68bea8b33236b0c1612 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 27 Mar 2021 18:01:43 +0100 Subject: [PATCH 08/15] new GroupingExprRef approach --- .../sql/catalyst/analysis/Analyzer.scala | 38 +++--- .../analysis/DeduplicateRelations.scala | 2 +- .../UnsupportedOperationChecker.scala | 3 +- ...lability.scala => UpdateNullability.scala} | 23 +++- .../spark/sql/catalyst/dsl/package.scala | 5 +- .../expressions/constraintExpressions.scala | 6 - .../sql/catalyst/expressions/grouping.scala | 19 +++ .../sql/catalyst/optimizer/ComplexTypes.scala | 11 +- ...nforceGroupingReferencesInAggregates.scala | 34 +++++ .../sql/catalyst/optimizer/Optimizer.scala | 49 +++++-- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../catalyst/optimizer/finishAnalysis.scala | 35 ----- .../sql/catalyst/parser/AstBuilder.scala | 8 +- .../plans/logical/basicLogicalOperators.scala | 128 ++++++++++++++++-- .../sql/catalyst/plans/logical/package.scala | 26 ++++ .../ResolveGroupingAnalyticsSuite.scala | 84 ++++++------ .../optimizer/AggregateOptimizeSuite.scala | 14 +- ...EliminateSortsBeforeRepartitionSuite.scala | 8 +- .../optimizer/complexTypesSuite.scala | 8 -- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 6 +- .../execution/python/ExtractPythonUDFs.scala | 7 +- .../approved-plans-v1_4/q62.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q62/explain.txt | 2 +- .../approved-plans-v1_4/q99.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q99/explain.txt | 2 +- 27 files changed, 357 insertions(+), 173 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{UpdateAttributeNullability.scala => UpdateNullability.scala} (73%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4b8fc339370d..d1621d0c48aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -430,7 +430,7 @@ class Analyzer(override val catalogManager: CatalogManager) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => - Aggregate(groups, assignAliases(aggs), child) + Aggregate(groups, assignAliases(aggs), child, false) case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => @@ -599,7 +599,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aggregations = constructAggregateExprs( finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) - Aggregate(groupingAttrs, aggregations, expand) + Aggregate(groupingAttrs, aggregations, expand, false) } private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { @@ -746,14 +746,15 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => Alias(pivotColumn, "__pivot_col")() } val bigGroup = 
groupByExprs :+ namedPivotCol - val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) + val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child, false) val pivotAggs = namedAggExps.map { a => Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } val groupByExprsAttr = groupByExprs.map(_.toAttribute) - val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) + val secondAgg = + Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg, false) val pivotAggAttribute = pivotAggs.map(_.toAttribute) val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => @@ -790,7 +791,7 @@ class Analyzer(override val catalogManager: CatalogManager) Alias(filteredAggregate, outputName(value, aggregate))() } } - Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child, false) } } @@ -1406,7 +1407,8 @@ class Analyzer(override val catalogManager: CatalogManager) if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError() } else { - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + a.copy(aggrExprWithGroupingRefs = + buildExpandedProjectList(a.aggregateExpressions, a.child)) } // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => @@ -1819,7 +1821,7 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal) case o => o } - Aggregate(newGroups, aggs, child) + Aggregate(newGroups, aggs, child, false) } } @@ -1917,7 +1919,8 @@ class Analyzer(override val catalogManager: CatalogManager) val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. - (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + (newExprs, + a.copy(aggrExprWithGroupingRefs = aggExprs ++ missingAttrs, child = newChild)) } else { // Need to add non-grouping attributes, invalid case. 
(exprs, a) @@ -2238,7 +2241,7 @@ class Analyzer(override val catalogManager: CatalogManager) object GlobalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => - Aggregate(Nil, projectList, child) + Aggregate(Nil, projectList, child, false) } def containsAggregates(exprs: Seq[Expression]): Boolean = { @@ -2287,7 +2290,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregateWithExtraOrdering = aggregate.copy( - aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) + aggrExprWithGroupingRefs = aggregate.aggregateExpressions ++ aliasedOrdering) val resolvedAggregate: Aggregate = executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] @@ -2341,7 +2344,7 @@ class Analyzer(override val catalogManager: CatalogManager) } else { Project(aggregate.output, Sort(finalSortOrders, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + aggregate.copy(aggrExprWithGroupingRefs = originalAggExprs ++ needsPushDown))) } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, @@ -2368,7 +2371,8 @@ class Analyzer(override val catalogManager: CatalogManager) Aggregate( agg.groupingExpressions, Alias(filterCond, "havingCondition")() :: Nil, - agg.child) + agg.child, + false) val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator @@ -2423,7 +2427,7 @@ class Analyzer(override val catalogManager: CatalogManager) val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get Project(agg.output, Filter(resolvedHavingCond, - agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions))) + agg.copy(aggrExprWithGroupingRefs = agg.aggregateExpressions ++ aggregateExpressions))) } else { filter } @@ -2553,7 +2557,7 @@ class Analyzer(override val catalogManager: CatalogManager) other :: Nil } - val newAgg = Aggregate(groupList, newAggList, child) + val newAgg = Aggregate(groupList, newAggList, child, false) Project(projectExprs.toList, newAgg) case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) => @@ -2863,7 +2867,7 @@ class Analyzer(override val catalogManager: CatalogManager) a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. - val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) + val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false) // Add a Filter operator for conditions in the Having clause. val withFilter = Filter(condition, withAggregate) val withWindow = addWindow(windowExpressions, withFilter) @@ -2880,7 +2884,7 @@ class Analyzer(override val catalogManager: CatalogManager) a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. - val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) + val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false) // Add Window operators. 
val withWindow = addWindow(windowExpressions, withAggregate) @@ -3538,7 +3542,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { case Aggregate(grouping, aggs, child) => val cleanedAggs = aggs.map(trimNonTopLevelAliases) - Aggregate(grouping.map(trimAliases), cleanedAggs, child) + Aggregate(grouping.map(trimAliases), cleanedAggs, child, false) case Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index fdd9df061b5f..3213cd024756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -167,7 +167,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] { case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy( - aggregateExpressions = newAliases(aggregateExpressions)))) + aggrExprWithGroupingRefs = newAliases(aggregateExpressions)))) // We don't search the child plan recursively for the same reason as the above Project. case _ @ Aggregate(_, aggregateExpressions, _) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 0d586e726191..f3bcc4fb1ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -107,7 +107,8 @@ object UnsupportedOperationChecker extends Logging { // Since the Distinct node will be replaced to Aggregate in the optimizer rule // [[ReplaceDistinctWithAggregate]], here we also need to check all Distinct node by // assuming it as Aggregate. 
- case d @ Distinct(c: LogicalPlan) if d.isStreaming => Aggregate(c.output, c.output, c) + case d @ Distinct(c: LogicalPlan) if d.isStreaming => + Aggregate(c.output, c.output, c, false) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala similarity index 73% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala index 5004108d348b..3c2a18d70edc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** @@ -52,3 +52,22 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability of [[GroupingExprRef]]s in a resolved LogicalPlan by using the nullability of + * referenced grouping expression. + */ +object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case a: Aggregate => + val nullabilities = a.groupingExpressions.map(_.nullable).toArray + + val newAggrExprWithGroupingRefs = + a.aggrExprWithGroupingRefs.map(_.transform { + case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) => + g.copy(nullable = nullabilities(g.ordinal)) + }.asInstanceOf[NamedExpression]) + + a.copy(aggrExprWithGroupingRefs = newAggrExprWithGroupingRefs) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 626ece33f157..611133da7ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.optimizer.EnforceGroupingReferencesInAggregates import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -407,7 +408,7 @@ package object dsl { case ne: NamedExpression => ne case e => Alias(e, e.toString)() } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) + Aggregate(groupingExprs, aliasedExprs, logicalPlan, false) } def having( @@ -466,7 +467,7 @@ package object dsl { def analyze: LogicalPlan = { val analyzed = analysis.SimpleAnalyzer.execute(logicalPlan) analysis.SimpleAnalyzer.checkAnalysis(analyzed) - EliminateSubqueryAliases(analyzed) + EnforceGroupingReferencesInAggregates(EliminateSubqueryAliases(analyzed)) } def hint(name: String, parameters: Any*): LogicalPlan = diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index cdaed2bf404e..5bfae7b77e09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -39,9 +39,3 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { } case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression - -/** - * Wrapper expression to avoid further optimizations between the parent and child of this - * expression. - */ -case class GroupingExpression(child: Expression) extends TaggingExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index c6b67d62d181..50879aa23d43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -212,3 +212,22 @@ object GroupingID { if (SQLConf.get.integerGroupingIdEnabled) IntegerType else LongType } } + +/** + * A reference to an grouping expression in [[Aggregate]] node. + * + * @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression + * refers to. + * @param dataType The [[DataType]] of the referenced grouping expression. + * @param nullable True if null is a valid value for the referenced grouping expression. + */ +case class GroupingExprRef( + ordinal: Int, + dataType: DataType, + nullable: Boolean) + extends LeafExpression with Unevaluable { + + override def stringArgs: Iterator[Any] = { + Iterator(ordinal) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 0ff11ca49f3d..8f1548a9788a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule /** @@ -26,15 +26,6 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object SimplifyExtractValueOps extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // One place where this optimization is invalid is an aggregation where the select - // list expression is a function of a grouping expression: - // - // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) - // - // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this - // optimization for Aggregates (although this misses some cases where the optimization - // can be made). - case a: Aggregate => a case p => p.transformExpressionsUp { // Remove redundant field extraction. 
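      // For example (an illustrative case, not taken from this patch):
      //   GetStructField(CreateNamedStruct(Seq("a", a, "b", b)), 0) simplifies to a.
      // With GroupingExprRef in place this is safe even under an Aggregate, because grouping
      // expressions in the project list are now ordinal references rather than expression trees
      // that this rule could rewrite away from the GROUP BY keys.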
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala new file mode 100644 index 000000000000..de42b3631893 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * This rule ensures that [[Aggregate]] nodes contain all required [[GroupingExprRef]] + * references for optimization phase. + */ +object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case a: Aggregate if !a.enforceGroupingReferences => + Aggregate(a.groupingExpressions, a.aggrExprWithGroupingRefs, a.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index cb8b6ab748de..216c8fbd0806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,7 +118,8 @@ abstract class Optimizer(catalogManager: CatalogManager) OptimizeUpdateFields, SimplifyExtractValueOps, OptimizeCsvJsonExprs, - CombineConcats) ++ + CombineConcats, + UpdateGroupingExprRefNullability) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { @@ -147,7 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView, ReplaceExpressions, RewriteNonCorrelatedExists, - WrapGroupingExpressions, + EnforceGroupingReferencesInAggregates, ComputeCurrentTime, GetCurrentDatabaseAndCatalog(catalogManager)) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -507,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) => val aliasMap = getAliasMap(lower) - val newAggregate = upper.copy( + val newAggregate = Aggregate( child = lower.child, groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), aggregateExpressions = upper.aggregateExpressions.map( @@ -751,8 +752,8 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, p2: Project) if 
!p2.outputSet.subsetOf(p.references) => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) => - p.copy( - child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + p.copy(child = + a.copy(aggrExprWithGroupingRefs = a.aggrExprWithGroupingRefs.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -878,8 +879,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - agg.copy(aggregateExpressions = buildCleanedProjectList( - p.projectList, agg.aggregateExpressions)) + Aggregate(agg.groupingExpressions, + buildCleanedProjectList(p.projectList, agg.aggregateExpressions), agg.child) } case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => @@ -1249,6 +1250,7 @@ object EliminateSorts extends Rule[LogicalPlan] { def checkValidAggregateExpression(expr: Expression): Boolean = expr match { case _: AttributeReference => true + case _: GroupingExprRef => true case ae: AggregateExpression => isOrderIrrelevantAggFunction(ae.aggregateFunction) case _: UserDefinedExpression => false case e => e.children.forall(checkValidAggregateExpression) @@ -1980,7 +1982,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { case a @ Aggregate(grouping, _, _) if grouping.nonEmpty => val newGrouping = grouping.filter(!_.foldable) if (newGrouping.nonEmpty) { - a.copy(groupingExpressions = newGrouping) + val droppedGroupsBefore = + grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray + + val newAggrExprWithGroupingReferences = + a.aggrExprWithGroupingRefs.map(_.transform { + case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => + g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) + }.asInstanceOf[NamedExpression]) + + a.copy( + groupingExpressions = newGrouping, + aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences) } else { // All grouping expressions are literals. We should not drop them all, because this can // change the return semantics when the input of the Aggregate is empty (SPARK-17114). 
We @@ -2001,7 +2014,25 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { if (newGrouping.size == grouping.size) { a } else { - a.copy(groupingExpressions = newGrouping) + var i = 0 + val droppedGroupsBefore = grouping.scanLeft(0)((n, e) => + n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) { + i += 1 + 0 + } else { + 1 + }) + ).toArray + + val newAggrExprWithGroupingReferences = + a.aggrExprWithGroupingRefs.map(_.transform { + case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => + g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) + }.asInstanceOf[NamedExpression]) + + a.copy( + groupingExpressions = newGrouping, + aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index c3d2f336f06f..77f0c719e426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -795,7 +795,7 @@ object NullPropagation extends Rule[LogicalPlan] { */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - CleanupAliases(propagateFoldables(plan)._1) + EnforceGroupingReferencesInAggregates(CleanupAliases(propagateFoldables(plan)._1)) } private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 2166a34421d9..897d74a0f947 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -66,41 +66,6 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { } } -/** - * Wrap some of the grouping expressions in aggregate expressions without aggregate functions into - * `GroupingExpression` nodes so as to avoid optimizations between the expression and its parent. - * This is required as optimizations could change these grouping expressions and so make the - * aggregate expression invalid. - * We only need to wrap complex expressions (expressions with children so they are more than just - * an attribute or a literal) which can be subject of optimizations. - * - * For example, in the following query Spark shouldn't optimize the aggregate expression - * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`: - * SELECT not(c IS NULL) - * FROM t - * GROUP BY c IS NULL - * This rule changes the aggregate expression to `Not(GroupingExpression(IsNull(c)))`. 
- */ -object WrapGroupingExpressions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { - case a: Aggregate => - val complexGroupingExpressions = - ExpressionSet(a.groupingExpressions.filter(_.children.nonEmpty)) - - def wrapGroupingExpression(e: Expression): Expression = e match { - case _: GroupingExpression => e - case _: AggregateExpression => e - case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e) - case _ => e.mapChildren(wrapGroupingExpression) - } - - a.copy(aggregateExpressions = - a.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression])) - } - } -} - /** * Computes the current date and time to make sure we return the same result in a single query. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index dc87398fc6a4..2db7cd7f8d7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -726,7 +726,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg Filter(predicate, createProject()) } else { // According to SQL standard, HAVING without GROUP BY means global aggregate. - withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) + withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter, false)) } } else if (aggregationClause != null) { val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) @@ -924,7 +924,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val groupingSets = ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), - selectExpressions, query) + selectExpressions, query, false) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? 
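For orientation, the CUBE branch below wraps the plain grouping list into a single Cube node; in the PlanParserSuite style used later in this patch, the expected shape looks like the following sketch (the trailing `false` is the new `enforceGroupingReferences` argument):

assertEqual(
  "SELECT a, b, sum(c) as c FROM d GROUP BY a, b WITH CUBE",
  Aggregate(Seq(Cube(Seq(Seq('a), Seq('b)))),
    Seq('a, 'b, 'sum.function('c).as("c")), table("d"), false))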
val mappedGroupByExpressions = if (ctx.CUBE != null) { @@ -934,7 +934,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else { groupByExpressions } - Aggregate(mappedGroupByExpressions, selectExpressions, query) + Aggregate(mappedGroupByExpressions, selectExpressions, query, false) } } else { val groupByExpressions = @@ -978,7 +978,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg "`GROUP BY CUBE(a, b), ROLLUP(a, c)` is not supported.", ctx) } - Aggregate(groupByExpressions.toSeq, selectExpressions, query) + Aggregate(groupByExpressions.toSeq, selectExpressions, query, false) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9461dbf9f3d1..b9b993e06ecf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable + import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} @@ -622,31 +624,47 @@ case class Range( /** * This is a Group by operator with the aggregate functions and projections. * - * @param groupingExpressions expressions for grouping keys - * @param aggregateExpressions expressions for a project list, which could contain - * [[AggregateExpression]]s. + * @param groupingExpressions Expressions for grouping keys. + * @param aggrExprWithGroupingRefs Expressions for a project list, which could contain + * [[AggregateExpression]]s and [[GroupingExprRef]]s. + * @param child The child of the aggregate node. + * @param enforceGroupingReferences If [[aggrExprWithGroupingRefs]] should contain + * [[GroupingExprRef]]s. + * + * [[aggrExprWithGroupingRefs]] without aggregate functions can contain [[GroupingExprRef]] + * expressions to refer to complex grouping expressions in [[groupingExpressions]]. These references + * ensure that optimization rules don't change the aggregate expressions to invalid ones that no + * longer refer to any grouping expressions and also simplify the expression transformations on the + * node (need to transform the expression only once). * - * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before - * separating projection from grouping and aggregate, we should avoid expression-level optimization - * on aggregateExpressions, which could reference an expression in groupingExpressions. - * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]] + * For example, in the following query Spark shouldn't optimize the aggregate expression + * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`: + * SELECT not(c IS NULL) + * FROM t + * GROUP BY c IS NULL + * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`. 
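The example in this doc comment can be reproduced end to end. A hedged sketch, assuming a SparkSession named `spark` (the view name `t` and the sample data are illustrative):

import spark.implicits._

Seq[Integer](null, 1, 2, 3, null).toDF("c").createOrReplaceTempView("t")
spark.sql(
  """SELECT not(c IS NULL), count(*)
    |FROM t
    |GROUP BY c IS NULL
    |""".stripMargin).show()
// Two groups: c IS NULL = true (2 rows) and false (3 rows), so the result
// is (false, 2) and (true, 3). If Not(IsNull(c)) were simplified to
// IsNotNull(c), the aggregate expression would no longer match any grouping
// expression -- the invalid plan SPARK-34581 guards against.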
*/ -case class Aggregate( +case class AggregateWithGroupingReferences( groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: LogicalPlan) + aggrExprWithGroupingRefs: Seq[NamedExpression], + child: LogicalPlan, + enforceGroupingReferences: Boolean) extends UnaryNode { + override val nodeName = "Aggregate" + + override val stringArgs = + Iterator(groupingExpressions, aggrExprWithGroupingRefs, child) override lazy val resolved: Boolean = { - val hasWindowExpressions = aggregateExpressions.exists ( _.collect { - case window: WindowExpression => window - }.nonEmpty - ) + val hasWindowExpressions = aggrExprWithGroupingRefs.exists(_.collect { + case window: WindowExpression => window + }.nonEmpty) !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggrExprWithGroupingRefs.map(_.toAttribute) + override def maxRows: Option[Long] = { if (groupingExpressions.isEmpty) { Some(1L) @@ -655,12 +673,92 @@ case class Aggregate( } } + private def expandGroupingReferences(e: Expression): Expression = { + e match { + case _: AggregateExpression => e + case _ if PythonUDF.isGroupedAggPandasUDF(e) => e + case g: GroupingExprRef => groupingExpressions(g.ordinal) + case _ => e.mapChildren(expandGroupingReferences) + } + } + + lazy val aggregateExpressions = { + if (enforceGroupingReferences) { + aggrExprWithGroupingRefs.map(expandGroupingReferences(_).asInstanceOf[NamedExpression]) + } else { + aggrExprWithGroupingRefs + } + } + override lazy val validConstraints: ExpressionSet = { val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) getAllValidConstraints(nonAgg) } } +object Aggregate { + private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = { + groupingExpressions.zipWithIndex + .foldLeft(mutable.Map.empty[Expression, (Expression, Int)]) { + case (m, (ge, i)) => + if (ge.deterministic && !ge.foldable && ge.children.nonEmpty && + !m.contains(ge.canonicalized)) { + m += ge.canonicalized -> (ge, i) + } + m + } + } + + private def insertGroupingReferences( + aggregateExpressions: Seq[NamedExpression], + groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = { + def insertGroupingExprRefs(e: Expression): Expression = { + e match { + case _ if !e.deterministic => e + case _: AggregateExpression => e + case _ if PythonUDF.isGroupedAggPandasUDF(e) => e + case _ if groupingExpressions.contains(e.canonicalized) => + val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized) + GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable) + case _ => e.mapChildren(insertGroupingExprRefs) + } + } + + aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression]) + } + + def apply( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan, + enforceGroupingReferences: Boolean = true): Aggregate = { + val (newGroupingExpressions, aggrExprWithGroupingReferences) = + if (enforceGroupingReferences) { + val dealiasedGroupingExpressions = groupingExpressions.map { + case a: Alias => a.child + case o => o + } + val complexGroupingExpressions = + collectComplexGroupingExpressions(dealiasedGroupingExpressions) + val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) { + insertGroupingReferences(aggregateExpressions, 
complexGroupingExpressions) + } else { + aggregateExpressions + } + (dealiasedGroupingExpressions, aggrExprWithGroupingReferences) + } else { + (groupingExpressions, aggregateExpressions) + } + new Aggregate(newGroupingExpressions, aggrExprWithGroupingReferences, child, + enforceGroupingReferences) + } + + def unapply( + aggregate: Aggregate): Option[(Seq[Expression], Seq[NamedExpression], LogicalPlan)] = { + Some(aggregate.groupingExpressions, aggregate.aggregateExpressions, aggregate.child) + } +} + case class Window( windowExpressions: Seq[NamedExpression], partitionSpec: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala new file mode 100644 index 000000000000..88b783b004ce --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +package object logical { + /** + * This Aggregate type alias together with Aggregate object keeps the semantics of Aggregate after + * the Aggregate node was renamed to AggregateWithGroupingReferences. 
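The companion object above also restores the pre-rename construction and extraction API. A standalone sketch of that compatibility pattern with toy types (not the Catalyst classes): a type alias plus a hand-written companion keeps both old-style construction and old-style pattern matching compiling after a rename, and the `unapply` hides the new field, just as `Aggregate.unapply` above omits `enforceGroupingReferences`.

object RenameCompatDemo extends App {
  case class NewName(keys: Seq[String], values: Seq[String], extra: Boolean)

  type OldName = NewName

  object OldName {
    def apply(keys: Seq[String], values: Seq[String], extra: Boolean = true): OldName =
      NewName(keys, values, extra)
    def unapply(n: OldName): Option[(Seq[String], Seq[String])] =
      Some((n.keys, n.values))
  }

  val node = OldName(Seq("k"), Seq("v"))   // old-style construction still compiles
  val OldName(keys, values) = node         // ...and so does old-style matching
  assert(keys == Seq("k") && values == Seq("v"))
}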
+ */ + type Aggregate = AggregateWithGroupingReferences +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 81fc86420e1e..dcb7865f5a1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -70,40 +70,40 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) val originalPlan2 = Aggregate( Seq(GroupingSets(Seq(), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan2, expected2) val originalPlan3 = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b), Seq(unresolved_c)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } test("grouping sets with no explicit group by expressions") { val originalPlan = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Nil)), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) // Computation of grouping expression should remove duplicate expression based on their @@ -112,7 +112,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil)), Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), - unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val resultPlan = getAnalyzer.executeAndCheck(originalPlan2, new QueryPlanningTracker) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions @@ -126,40 +126,42 @@ 
class ResolveGroupingAnalyticsSuite extends AnalysisTest { test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) - val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val originalPlan2 = + Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1, false) val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, 0L)), Seq(a, b, c, gid), - Project(Seq(a, b, c), r1))) + Project(Seq(a, b, c), r1)), false) checkAnalysis(originalPlan2, expected2) } test("rollup") { val originalPlan = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) - val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) + val originalPlan2 = + Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1, false) val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, 0L)), Seq(a, b, c, gid), - Project(Seq(a, b, c), r1))) + Project(Seq(a, b, c), r1)), false) checkAnalysis(originalPlan2, expected2) } @@ -169,38 +171,38 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1) + UnresolvedAlias(Grouping(unresolved_a))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) // Cube val originalPlan2 = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1) + UnresolvedAlias(Grouping(unresolved_a))), r1, false) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), 
false) checkAnalysis(originalPlan2, expected2) // Rollup val originalPlan3 = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1) + UnresolvedAlias(Grouping(unresolved_a))), r1, false) val expected3 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan3, expected3) } @@ -210,38 +212,38 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan, expected) // Cube val originalPlan2 = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan2, expected2) // Rollup val originalPlan3 = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) val expected3 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) checkAnalysis(originalPlan3, expected3) } @@ -251,7 +253,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b), r1)) + Seq(unresolved_a, unresolved_b), r1, false)) val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType, Option(TimeZone.getDefault().getID)) === 0, Aggregate(Seq(a, b, gid), @@ -260,11 +262,11 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - 
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) checkAnalysis(originalPlan, expected) val originalPlan2 = Filter(Grouping(unresolved_a) === 0, - Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1, false)) assertAnalysisError(originalPlan2, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) @@ -272,7 +274,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1L, Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) val expected3 = Project(Seq(a, b), Filter(gid === 1L, Aggregate(Seq(a, b, gid), Seq(a, b, gid), @@ -280,11 +282,11 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) checkAnalysis(originalPlan3, expected3) val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1, - Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1, false)) assertAnalysisError(originalPlan4, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } @@ -295,7 +297,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) val expected = Project(Seq(a, b), Sort( Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), @@ -304,11 +306,12 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) checkAnalysis(originalPlan, expected) val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, - Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1, + false)) assertAnalysisError(originalPlan2, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) @@ -316,7 +319,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan3 = Sort( Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true, Aggregate(Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) val expected3 = Project(Seq(a, b), Sort( Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), @@ -325,12 +328,13 @@ class 
ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) checkAnalysis(originalPlan3, expected3) val originalPlan4 = Sort( Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true, - Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1, + false)) assertAnalysisError(originalPlan4, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 8984bad479a6..a8aba39e389b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -29,10 +29,13 @@ class AggregateOptimizeSuite extends AnalysisTest { val analyzer = getAnalyzer object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("Aggregate", FixedPoint(100), - FoldablePropagation, - RemoveLiteralFromGroupExpressions, - RemoveRepetitionFromGroupExpressions) :: Nil + val batches = + Batch("Finish Analysis", Once, + EnforceGroupingReferencesInAggregates) :: + Batch("Aggregate", FixedPoint(100), + FoldablePropagation, + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -50,7 +53,8 @@ class AggregateOptimizeSuite extends AnalysisTest { val analyzer = getAnalyzer val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) + val correctAnswer = EnforceGroupingReferencesInAggregates( + analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))) comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala index 82db174ad41b..b5cbdd906ad8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala @@ -34,7 +34,9 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { val anotherTestRelation = LocalRelation('d.int, 'e.int) object Optimize extends RuleExecutor[LogicalPlan] { - val batches = + val batches = { + Batch("Finish Analysis", Once, + EnforceGroupingReferencesInAggregates) :: Batch("Default", FixedPoint(10), FoldablePropagation, LimitPushDown) :: @@ -42,6 +44,7 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { EliminateSorts) :: Batch("Collapse Project", Once, CollapseProject) :: Nil + } } def repartition(plan: LogicalPlan): LogicalPlan = plan.repartition(10) @@ -139,7 +142,8 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { val optimizedPlan = testRelation.distribute('a)(2).where('a === 10) val aggPlan = 
plan.groupBy('a)(sum('b)) val optimizedAggPlan = optimize(aggPlan) - val correctAggPlan = analyze(optimizedPlan.groupBy('a)(sum('b))) + val correctAggPlan = + EnforceGroupingReferencesInAggregates(analyze(optimizedPlan.groupBy('a)(sum('b)))) comparePlans(optimizedAggPlan, correctAggPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index dcd2fbbf0052..49cd3acd8242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -405,14 +405,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val arrayAggRel = relation.groupBy( CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) checkRule(arrayAggRel, arrayAggRel) - - // This could be done if we had a more complex rule that checks that - // the CreateMap does not come from key. - val originalQuery = relation - .groupBy('id)( - GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" - ) - checkRule(originalQuery, originalQuery) } test("SPARK-23500: namedStruct and getField in the same Project #1") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 8cd8ec8bd801..01850ead92c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -339,11 +339,11 @@ class PlanParserSuite extends AnalysisTest { // Grouping Sets assertEqual(s"$sql grouping sets((a, b), (a), ())", Aggregate(Seq(GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b))), - Seq('a, 'b, 'sum.function('c).as("c")), table("d"))) + Seq('a, 'b, 'sum.function('c).as("c")), table("d"), false)) assertEqual(s"$sqlWithoutGroupBy group by grouping sets((a, b), (a), ())", Aggregate(Seq(GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b))), - Seq('a, 'b, 'sum.function('c).as("c")), table("d"))) + Seq('a, 'b, 'sum.function('c).as("c")), table("d"), false)) val m = intercept[ParseException] { parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 76ee297dfca7..ad98fc757a7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -446,7 +446,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( } else { Alias(CreateStruct(groupingAttributes), "key")() } - val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan, false) val execution = new QueryExecution(sparkSession, aggregate) new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index bf3f7058b9bd..c8c964e928ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -66,15 +66,15 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan, false)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))), - aliasedAgg, df.logicalPlan)) + aliasedAgg, df.logicalPlan, false)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))), - aliasedAgg, df.logicalPlan)) + aliasedAgg, df.logicalPlan, false)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index b79bcd176b7b..d68c5f03b7c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -72,7 +72,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } } // There is no Python UDF over aggregate expression - Project(projList.toSeq, agg.copy(aggregateExpressions = aggExpr.toSeq)) + Project(projList.toSeq, Aggregate(agg.groupingExpressions, aggExpr.toSeq, agg.child)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -133,10 +133,7 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { attributeMap.getOrElse(canonicalized, p) }.asInstanceOf[NamedExpression] } - agg.copy( - groupingExpressions = groupingExpr.toSeq, - aggregateExpressions = aggExpr, - child = Project((projList ++ agg.child.output).toSeq, agg.child)) + Aggregate(groupingExpr.toSeq, aggExpr, Project((projList ++ agg.child.output).toSeq, agg.child)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt index a876fb70a86b..41462a738912 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62.sf100/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#13, web_name#10, sum# Keys [3]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#13, web_name#10] Functions [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND 
((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [groupingexpression(substr(w_warehouse_name#16, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [substr(w_warehouse_name#16, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#13, web_name#10, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt index 802c9a5f3c3a..b5aa53f2de12 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q62/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, web_name#13, sum#2 Keys [3]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, web_name#13] Functions [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((ws_ship_date_sk#1 - 
ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [groupingexpression(substr(w_warehouse_name#7, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [substr(w_warehouse_name#7, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 30) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 60) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 90) AND ((ws_ship_date_sk#1 - ws_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((ws_ship_date_sk#1 - ws_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, web_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt index 76553805e743..9d653586e519 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99.sf100/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#10, cc_name#13, sum#2 Keys [3]: [substr(w_warehouse_name#16, 1, 20)#18, sm_type#10, cc_name#13] Functions [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 
END)#30, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [groupingexpression(substr(w_warehouse_name#16, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [substr(w_warehouse_name#16, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt index ff180769e859..c05e9595220c 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q99/explain.txt @@ -175,7 +175,7 @@ Input [8]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, cc_name#13, sum#24 Keys [3]: [substr(w_warehouse_name#7, 1, 20)#18, sm_type#10, cc_name#13] Functions [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END), sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END), sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)] Aggregate Attributes [5]: [sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) 
THEN 1 ELSE 0 END)#30, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34] -Results [8]: [groupingexpression(substr(w_warehouse_name#7, 1, 20)#18) AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] +Results [8]: [substr(w_warehouse_name#7, 1, 20)#18 AS substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 30) THEN 1 ELSE 0 END)#30 AS 30 days #36, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 30) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 60)) THEN 1 ELSE 0 END)#31 AS 31 - 60 days #37, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 60) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 90)) THEN 1 ELSE 0 END)#32 AS 61 - 90 days #38, sum(CASE WHEN (((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 90) AND ((cs_ship_date_sk#1 - cs_sold_date_sk#5) <= 120)) THEN 1 ELSE 0 END)#33 AS 91 - 120 days #39, sum(CASE WHEN ((cs_ship_date_sk#1 - cs_sold_date_sk#5) > 120) THEN 1 ELSE 0 END)#34 AS >120 days #40] (32) TakeOrderedAndProject Input [8]: [substr(w_warehouse_name, 1, 20)#35, sm_type#10, cc_name#13, 30 days #36, 31 - 60 days #37, 61 - 90 days #38, 91 - 120 days #39, >120 days #40] From c2ba80457bd86d11ad26311bbc3c42607f33b19a Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 11 Apr 2021 10:20:50 +0200 Subject: [PATCH 09/15] simplify --- .../sql/catalyst/analysis/Analyzer.scala | 38 ++++----- .../analysis/DeduplicateRelations.scala | 2 +- .../UnsupportedOperationChecker.scala | 3 +- .../catalyst/analysis/UpdateNullability.scala | 6 +- .../spark/sql/catalyst/dsl/package.scala | 5 +- .../catalyst/expressions/AliasHelper.scala | 2 +- ...nforceGroupingReferencesInAggregates.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 27 +++--- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../sql/catalyst/parser/AstBuilder.scala | 8 +- .../sql/catalyst/planning/patterns.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 85 +++++++------------ .../sql/catalyst/plans/logical/package.scala | 26 ------ .../ResolveGroupingAnalyticsSuite.scala | 84 +++++++++--------- .../optimizer/AggregateOptimizeSuite.scala | 14 ++- ...EliminateSortsBeforeRepartitionSuite.scala | 8 +- .../RemoveRedundantAggregatesSuite.scala | 2 +- .../optimizer/complexTypesSuite.scala | 4 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- 
.../spark/sql/RelationalGroupedDataset.scala | 6 +- .../execution/python/ExtractPythonUDFs.scala | 21 ++--- 22 files changed, 142 insertions(+), 218 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d1621d0c48aa..4b8fc339370d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -430,7 +430,7 @@ class Analyzer(override val catalogManager: CatalogManager) def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => - Aggregate(groups, assignAliases(aggs), child, false) + Aggregate(groups, assignAliases(aggs), child) case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => @@ -599,7 +599,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aggregations = constructAggregateExprs( finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) - Aggregate(groupingAttrs, aggregations, expand, false) + Aggregate(groupingAttrs, aggregations, expand) } private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { @@ -746,15 +746,14 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => Alias(pivotColumn, "__pivot_col")() } val bigGroup = groupByExprs :+ namedPivotCol - val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child, false) + val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) val pivotAggs = namedAggExps.map { a => Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } val groupByExprsAttr = groupByExprs.map(_.toAttribute) - val secondAgg = - Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg, false) + val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) val pivotAggAttribute = pivotAggs.map(_.toAttribute) val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => @@ -791,7 +790,7 @@ class Analyzer(override val catalogManager: CatalogManager) Alias(filteredAggregate, outputName(value, aggregate))() } } - Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child, false) + Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } @@ -1407,8 +1406,7 @@ class Analyzer(override val catalogManager: CatalogManager) if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError() } else { - a.copy(aggrExprWithGroupingRefs = - buildExpandedProjectList(a.aggregateExpressions, a.child)) + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } // If the script transformation input contains Stars, expand it. 
case t: ScriptTransformation if containsStar(t.input) => @@ -1821,7 +1819,7 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal) case o => o } - Aggregate(newGroups, aggs, child, false) + Aggregate(newGroups, aggs, child) } } @@ -1919,8 +1917,7 @@ class Analyzer(override val catalogManager: CatalogManager) val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. - (newExprs, - a.copy(aggrExprWithGroupingRefs = aggExprs ++ missingAttrs, child = newChild)) + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) } else { // Need to add non-grouping attributes, invalid case. (exprs, a) @@ -2241,7 +2238,7 @@ class Analyzer(override val catalogManager: CatalogManager) object GlobalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => - Aggregate(Nil, projectList, child, false) + Aggregate(Nil, projectList, child) } def containsAggregates(exprs: Seq[Expression]): Boolean = { @@ -2290,7 +2287,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregateWithExtraOrdering = aggregate.copy( - aggrExprWithGroupingRefs = aggregate.aggregateExpressions ++ aliasedOrdering) + aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) val resolvedAggregate: Aggregate = executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] @@ -2344,7 +2341,7 @@ class Analyzer(override val catalogManager: CatalogManager) } else { Project(aggregate.output, Sort(finalSortOrders, global, - aggregate.copy(aggrExprWithGroupingRefs = originalAggExprs ++ needsPushDown))) + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, @@ -2371,8 +2368,7 @@ class Analyzer(override val catalogManager: CatalogManager) Aggregate( agg.groupingExpressions, Alias(filterCond, "havingCondition")() :: Nil, - agg.child, - false) + agg.child) val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator @@ -2427,7 +2423,7 @@ class Analyzer(override val catalogManager: CatalogManager) val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get Project(agg.output, Filter(resolvedHavingCond, - agg.copy(aggrExprWithGroupingRefs = agg.aggregateExpressions ++ aggregateExpressions))) + agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions))) } else { filter } @@ -2557,7 +2553,7 @@ class Analyzer(override val catalogManager: CatalogManager) other :: Nil } - val newAgg = Aggregate(groupList, newAggList, child, false) + val newAgg = Aggregate(groupList, newAggList, child) Project(projectExprs.toList, newAgg) case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) => @@ -2867,7 +2863,7 @@ class Analyzer(override val catalogManager: CatalogManager) a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. 
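For orientation, a query that reaches the branch below combines an aggregate, a HAVING clause, and a window function in one select; a sketch assuming a SparkSession `spark` and a view t(a, b):

val df = spark.sql(
  """SELECT a, sum(b) AS s,
    |       rank() OVER (ORDER BY sum(b)) AS r
    |FROM t
    |GROUP BY a
    |HAVING sum(b) > 0
    |""".stripMargin)
// The branch splits the select list into window and aggregate parts, builds
// the Aggregate first, applies the HAVING Filter on top of it, and only
// then adds the Window operator(s).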
- val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false) + val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) // Add a Filter operator for conditions in the Having clause. val withFilter = Filter(condition, withAggregate) val withWindow = addWindow(windowExpressions, withFilter) @@ -2884,7 +2880,7 @@ class Analyzer(override val catalogManager: CatalogManager) a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. - val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false) + val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) // Add Window operators. val withWindow = addWindow(windowExpressions, withAggregate) @@ -3542,7 +3538,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { case Aggregate(grouping, aggs, child) => val cleanedAggs = aggs.map(trimNonTopLevelAliases) - Aggregate(grouping.map(trimAliases), cleanedAggs, child, false) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) case Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 3213cd024756..fdd9df061b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -167,7 +167,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] { case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy( - aggrExprWithGroupingRefs = newAliases(aggregateExpressions)))) + aggregateExpressions = newAliases(aggregateExpressions)))) // We don't search the child plan recursively for the same reason as the above Project. case _ @ Aggregate(_, aggregateExpressions, _) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index f3bcc4fb1ef8..0d586e726191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -107,8 +107,7 @@ object UnsupportedOperationChecker extends Logging { // Since the Distinct node will be replaced to Aggregate in the optimizer rule // [[ReplaceDistinctWithAggregate]], here we also need to check all Distinct node by // assuming it as Aggregate. 
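The Distinct-as-Aggregate equivalence used by this check can be spelled out with the Catalyst test DSL (a sketch assuming the catalyst test classpath; the relation is illustrative):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation}

val child = LocalRelation('a.int, 'b.int)
// ReplaceDistinctWithAggregate performs this rewrite in the optimizer, so
// the streaming checker treats the two plans identically:
val distinct = Distinct(child)
val asAggregate = Aggregate(child.output, child.output, child)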
- case d @ Distinct(c: LogicalPlan) if d.isStreaming => - Aggregate(c.output, c.output, c, false) + case d @ Distinct(c: LogicalPlan) if d.isStreaming => Aggregate(c.output, c.output, c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala index 3c2a18d70edc..d9a52f3ccf2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala @@ -62,12 +62,12 @@ object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] { case a: Aggregate => val nullabilities = a.groupingExpressions.map(_.nullable).toArray - val newAggrExprWithGroupingRefs = - a.aggrExprWithGroupingRefs.map(_.transform { + val newAggregateExpressions = + a.aggregateExpressions.map(_.transform { case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) => g.copy(nullable = nullabilities(g.ordinal)) }.asInstanceOf[NamedExpression]) - a.copy(aggrExprWithGroupingRefs = newAggrExprWithGroupingRefs) + a.copy(aggregateExpressions = newAggregateExpressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 611133da7ee7..626ece33f157 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.optimizer.EnforceGroupingReferencesInAggregates import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -408,7 +407,7 @@ package object dsl { case ne: NamedExpression => ne case e => Alias(e, e.toString)() } - Aggregate(groupingExprs, aliasedExprs, logicalPlan, false) + Aggregate(groupingExprs, aliasedExprs, logicalPlan) } def having( @@ -467,7 +466,7 @@ package object dsl { def analyze: LogicalPlan = { val analyzed = analysis.SimpleAnalyzer.execute(logicalPlan) analysis.SimpleAnalyzer.checkAnalysis(analyzed) - EnforceGroupingReferencesInAggregates(EliminateSubqueryAliases(analyzed)) + EliminateSubqueryAliases(analyzed) } def hint(name: String, parameters: Any*): LogicalPlan = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index 1f3f76266225..e9673d7f20f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -35,7 +35,7 @@ trait AliasHelper { protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = { // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression or PythonUDF, and create a map from the alias to the expression - val aliasMap = plan.aggregateExpressions.collect { + val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect { case a: Alias if a.child.find(e => 
e.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(e)).isEmpty => (a.toAttribute, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala index de42b3631893..74042fcbc85b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan transform { - case a: Aggregate if !a.enforceGroupingReferences => - Aggregate(a.groupingExpressions, a.aggrExprWithGroupingRefs, a.child) + case a: Aggregate => + Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 216c8fbd0806..4a8ded1d3ffe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -508,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) => val aliasMap = getAliasMap(lower) - val newAggregate = Aggregate( + val newAggregate = Aggregate.withGroupingRefs( child = lower.child, groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), aggregateExpressions = upper.aggregateExpressions.map( @@ -752,8 +752,8 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) => - p.copy(child = - a.copy(aggrExprWithGroupingRefs = a.aggrExprWithGroupingRefs.filter(p.references.contains))) + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -879,8 +879,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - Aggregate(agg.groupingExpressions, - buildCleanedProjectList(p.projectList, agg.aggregateExpressions), agg.child) + agg.copy(aggregateExpressions = buildCleanedProjectList( + p.projectList, agg.aggregateExpressions)) } case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => @@ -1250,7 +1250,6 @@ object EliminateSorts extends Rule[LogicalPlan] { def checkValidAggregateExpression(expr: Expression): Boolean = expr match { case _: AttributeReference => true - case _: GroupingExprRef => true case ae: AggregateExpression => isOrderIrrelevantAggFunction(ae.aggregateFunction) case _: UserDefinedExpression => false case e => e.children.forall(checkValidAggregateExpression) @@ 
-1985,15 +1984,15 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { val droppedGroupsBefore = grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray - val newAggrExprWithGroupingReferences = - a.aggrExprWithGroupingRefs.map(_.transform { + val newAggregateExpressions = + a.aggregateExpressions.map(_.transform { case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) }.asInstanceOf[NamedExpression]) - a.copy( - groupingExpressions = newGrouping, - aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences) + a.copy( + groupingExpressions = newGrouping, + aggregateExpressions = newAggregateExpressions) } else { // All grouping expressions are literals. We should not drop them all, because this can // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We @@ -2024,15 +2023,15 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { }) ).toArray - val newAggrExprWithGroupingReferences = - a.aggrExprWithGroupingRefs.map(_.transform { + val newAggregateExpressions = + a.aggregateExpressions.map(_.transform { case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) }.asInstanceOf[NamedExpression]) a.copy( groupingExpressions = newGrouping, - aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences) + aggregateExpressions = newAggregateExpressions) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 05678b7bbdab..392fcd122b47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -612,9 +612,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe * subqueries. */ def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { - case a @ Aggregate(grouping, expressions, child) => + case a @ Aggregate(grouping, _, child) => val subqueries = ArrayBuffer.empty[ScalarSubquery] - val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs + .map(extractCorrelatedScalarSubqueries(_, subqueries)) if (subqueries.nonEmpty) { // We currently only allow correlated subqueries in an aggregate if they are part of the // grouping expressions. As a result we need to replace all the scalar subqueries in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2db7cd7f8d7d..dc87398fc6a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -726,7 +726,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg Filter(predicate, createProject()) } else { // According to SQL standard, HAVING without GROUP BY means global aggregate. 
- withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter, false)) + withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) } } else if (aggregationClause != null) { val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) @@ -924,7 +924,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val groupingSets = ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), - selectExpressions, query, false) + selectExpressions, query) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? val mappedGroupByExpressions = if (ctx.CUBE != null) { @@ -934,7 +934,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else { groupByExpressions } - Aggregate(mappedGroupByExpressions, selectExpressions, query, false) + Aggregate(mappedGroupByExpressions, selectExpressions, query) } } else { val groupByExpressions = @@ -978,7 +978,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg "`GROUP BY CUBE(a, b), ROLLUP(a, c)` is not supported.", ctx) } - Aggregate(groupByExpressions.toSeq, selectExpressions, query, false) + Aggregate(groupByExpressions.toSeq, selectExpressions, query) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 2880e87ab156..fd0adb2f39dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -287,7 +287,7 @@ object PhysicalAggregation { (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { - case logical.Aggregate(groupingExpressions, resultExpressions, child) => + case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll // build a set of semantically distinct aggregate expressions and re-write expressions so @@ -322,7 +322,7 @@ object PhysicalAggregation { // which takes the grouping columns and final aggregate result buffer as input. // Thus, we must re-write the result expressions so that their attributes match up with // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => + val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr => expr.transformDown { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b9b993e06ecf..d07872537667 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -625,17 +625,15 @@ case class Range( * This is a Group by operator with the aggregate functions and projections. * * @param groupingExpressions Expressions for grouping keys. 
- * @param aggrExprWithGroupingRefs Expressions for a project list, which could contain - * [[AggregateExpression]]s and [[GroupingExprRef]]s. + * @param aggregateExpressions Expressions for a project list, which can contain + * [[AggregateExpression]]s and [[GroupingExprRef]]s. * @param child The child of the aggregate node. - * @param enforceGroupingReferences If [[aggrExprWithGroupingRefs]] should contain - * [[GroupingExprRef]]s. * - * [[aggrExprWithGroupingRefs]] without aggregate functions can contain [[GroupingExprRef]] - * expressions to refer to complex grouping expressions in [[groupingExpressions]]. These references - * ensure that optimization rules don't change the aggregate expressions to invalid ones that no - * longer refer to any grouping expressions and also simplify the expression transformations on the - * node (need to transform the expression only once). + * Expressions without aggregate functions in [[aggregateExpressions]] can contain + * [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. These + * references ensure that optimization rules don't change the aggregate expressions to invalid ones + * that no longer refer to any grouping expressions and also simplify the expression transformations + * on the node (need to transform the expression only once). * * For example, in the following query Spark shouldn't optimize the aggregate expression * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`: @@ -644,27 +642,22 @@ case class Range( * GROUP BY c IS NULL * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`. */ -case class AggregateWithGroupingReferences( +case class Aggregate( groupingExpressions: Seq[Expression], - aggrExprWithGroupingRefs: Seq[NamedExpression], - child: LogicalPlan, - enforceGroupingReferences: Boolean) + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode { - override val nodeName = "Aggregate" - - override val stringArgs = - Iterator(groupingExpressions, aggrExprWithGroupingRefs, child) override lazy val resolved: Boolean = { - val hasWindowExpressions = aggrExprWithGroupingRefs.exists(_.collect { - case window: WindowExpression => window - }.nonEmpty) + val hasWindowExpressions = aggregateExpressions.exists ( _.collect { + case window: WindowExpression => window + }.nonEmpty + ) !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - override def output: Seq[Attribute] = aggrExprWithGroupingRefs.map(_.toAttribute) - + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = { if (groupingExpressions.isEmpty) { Some(1L) @@ -682,12 +675,8 @@ case class AggregateWithGroupingReferences( } } - lazy val aggregateExpressions = { - if (enforceGroupingReferences) { - aggrExprWithGroupingRefs.map(expandGroupingReferences(_).asInstanceOf[NamedExpression]) - } else { - aggrExprWithGroupingRefs - } + lazy val aggregateExpressionsWithoutGroupingRefs = { + aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression]) } override lazy val validConstraints: ExpressionSet = { @@ -727,35 +716,23 @@ object Aggregate { aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression]) } - def apply( + def withGroupingRefs( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: LogicalPlan, - enforceGroupingReferences: Boolean = true): Aggregate = { - val (newGroupingExpressions, 
aggrExprWithGroupingReferences) = - if (enforceGroupingReferences) { - val dealiasedGroupingExpressions = groupingExpressions.map { - case a: Alias => a.child - case o => o - } - val complexGroupingExpressions = - collectComplexGroupingExpressions(dealiasedGroupingExpressions) - val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) { - insertGroupingReferences(aggregateExpressions, complexGroupingExpressions) - } else { - aggregateExpressions - } - (dealiasedGroupingExpressions, aggrExprWithGroupingReferences) - } else { - (groupingExpressions, aggregateExpressions) - } - new Aggregate(newGroupingExpressions, aggrExprWithGroupingReferences, child, - enforceGroupingReferences) - } + child: LogicalPlan): Aggregate = { + val dealiasedGroupingExpressions = groupingExpressions.map { + case a: Alias => a.child + case o => o + } + val complexGroupingExpressions = + collectComplexGroupingExpressions(dealiasedGroupingExpressions) + val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) { + insertGroupingReferences(aggregateExpressions, complexGroupingExpressions) + } else { + aggregateExpressions + } - def unapply( - aggregate: Aggregate): Option[(Seq[Expression], Seq[NamedExpression], LogicalPlan)] = { - Some(aggregate.groupingExpressions, aggregate.aggregateExpressions, aggregate.child) + new Aggregate(dealiasedGroupingExpressions, aggrExprWithGroupingReferences, child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala deleted file mode 100644 index 88b783b004ce..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/package.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -package object logical { - /** - * This Aggregate type alias together with Aggregate object keeps the semantics of Aggregate after - * the Aggregate node was renamed to AggregateWithGroupingReferences. 
- */ - type Aggregate = AggregateWithGroupingReferences -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index dcb7865f5a1a..81fc86420e1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -70,40 +70,40 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) val originalPlan2 = Aggregate( Seq(GroupingSets(Seq(), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan2, expected2) val originalPlan3 = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b), Seq(unresolved_c)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } test("grouping sets with no explicit group by expressions") { val originalPlan = Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Nil)), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) // Computation of grouping expression should remove duplicate expression based on their @@ -112,7 +112,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil)), Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), - unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val resultPlan = getAnalyzer.executeAndCheck(originalPlan2, new QueryPlanningTracker) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions @@ -126,42 +126,40 @@ 
class ResolveGroupingAnalyticsSuite extends AnalysisTest { test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) - val originalPlan2 = - Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1, false) + val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, 0L)), Seq(a, b, c, gid), - Project(Seq(a, b, c), r1)), false) + Project(Seq(a, b, c), r1))) checkAnalysis(originalPlan2, expected2) } test("rollup") { val originalPlan = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), - Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1, false) + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) - val originalPlan2 = - Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1, false) + val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), Expand( Seq(Seq(a, b, c, 0L)), Seq(a, b, c, gid), - Project(Seq(a, b, c), r1)), false) + Project(Seq(a, b, c), r1))) checkAnalysis(originalPlan2, expected2) } @@ -171,38 +169,38 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1, false) + UnresolvedAlias(Grouping(unresolved_a))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) // Cube val originalPlan2 = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1, false) + UnresolvedAlias(Grouping(unresolved_a))), r1) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), 
r1))) checkAnalysis(originalPlan2, expected2) // Rollup val originalPlan3 = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(Grouping(unresolved_a))), r1, false) + UnresolvedAlias(Grouping(unresolved_a))), r1) val expected3 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan3, expected3) } @@ -212,38 +210,38 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) // Cube val originalPlan2 = Aggregate(Seq(Cube(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, b, 2L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan2, expected2) // Rollup val originalPlan3 = Aggregate(Seq(Rollup(Seq(Seq(unresolved_a), Seq(unresolved_b)))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)), - UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1, false) + UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1) val expected3 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")), Expand( Seq(Seq(a, b, c, a, b, 0L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, nulInt, nulStr, 3L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan3, expected3) } @@ -253,7 +251,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b))), - Seq(unresolved_a, unresolved_b), r1, false)) + Seq(unresolved_a, unresolved_b), r1)) val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType, Option(TimeZone.getDefault().getID)) === 0, Aggregate(Seq(a, b, gid), @@ -262,11 +260,11 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - 
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan, expected) val originalPlan2 = Filter(Grouping(unresolved_a) === 0, - Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1, false)) + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) assertAnalysisError(originalPlan2, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) @@ -274,7 +272,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1L, Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) val expected3 = Project(Seq(a, b), Filter(gid === 1L, Aggregate(Seq(a, b, gid), Seq(a, b, gid), @@ -282,11 +280,11 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan3, expected3) val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1, - Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1, false)) + Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1)) assertAnalysisError(originalPlan4, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } @@ -297,7 +295,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, Aggregate( Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) val expected = Project(Seq(a, b), Sort( Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), @@ -306,12 +304,11 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan, expected) val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true, - Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1, - false)) + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) assertAnalysisError(originalPlan2, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) @@ -319,7 +316,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan3 = Sort( Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true, Aggregate(Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), - Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1, false)) + Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b), r1)) val expected3 = Project(Seq(a, b), Sort( Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), @@ -328,13 +325,12 @@ class 
ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L), Seq(a, b, c, a, b, 0L)), Seq(a, b, c, a, b, gid), - Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)), false))) + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))))) checkAnalysis(originalPlan3, expected3) val originalPlan4 = Sort( Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true, - Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1, - false)) + Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1)) assertAnalysisError(originalPlan4, Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index a8aba39e389b..8984bad479a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -29,13 +29,10 @@ class AggregateOptimizeSuite extends AnalysisTest { val analyzer = getAnalyzer object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Finish Analysis", Once, - EnforceGroupingReferencesInAggregates) :: - Batch("Aggregate", FixedPoint(100), - FoldablePropagation, - RemoveLiteralFromGroupExpressions, - RemoveRepetitionFromGroupExpressions) :: Nil + val batches = Batch("Aggregate", FixedPoint(100), + FoldablePropagation, + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -53,8 +50,7 @@ class AggregateOptimizeSuite extends AnalysisTest { val analyzer = getAnalyzer val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = EnforceGroupingReferencesInAggregates( - analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))) + val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala index b5cbdd906ad8..82db174ad41b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala @@ -34,9 +34,7 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { val anotherTestRelation = LocalRelation('d.int, 'e.int) object Optimize extends RuleExecutor[LogicalPlan] { - val batches = { - Batch("Finish Analysis", Once, - EnforceGroupingReferencesInAggregates) :: + val batches = Batch("Default", FixedPoint(10), FoldablePropagation, LimitPushDown) :: @@ -44,7 +42,6 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { EliminateSorts) :: Batch("Collapse Project", Once, CollapseProject) :: Nil - } } def repartition(plan: LogicalPlan): LogicalPlan = plan.repartition(10) @@ -142,8 +139,7 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { val optimizedPlan = testRelation.distribute('a)(2).where('a === 10) val aggPlan = 
plan.groupBy('a)(sum('b)) val optimizedAggPlan = optimize(aggPlan) - val correctAggPlan = - EnforceGroupingReferencesInAggregates(analyze(optimizedPlan.groupBy('a)(sum('b)))) + val correctAggPlan = analyze(optimizedPlan.groupBy('a)(sum('b))) comparePlans(optimizedAggPlan, correctAggPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala index d376c31ef965..3eba003d7752 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala @@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy('a + 'b)(('a + 'b) as 'c) .analyze val optimized = Optimize.execute(query) - comparePlans(optimized, expected) + comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 49cd3acd8242..d14996709401 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -36,6 +36,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = + Batch("Finish Analysis", Once, + EnforceGroupingReferencesInAggregates) :: Batch("collapse projections", FixedPoint(10), CollapseProject) :: Batch("Constant Folding", FixedPoint(10), @@ -57,7 +59,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { val optimized = Optimizer.execute(originalQuery.analyze) assert(optimized.resolved, "optimized plans must be still resolvable") - comparePlans(optimized, correctAnswer.analyze) + comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze)) } test("explicit get from namedStruct") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 01850ead92c6..8cd8ec8bd801 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -339,11 +339,11 @@ class PlanParserSuite extends AnalysisTest { // Grouping Sets assertEqual(s"$sql grouping sets((a, b), (a), ())", Aggregate(Seq(GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b))), - Seq('a, 'b, 'sum.function('c).as("c")), table("d"), false)) + Seq('a, 'b, 'sum.function('c).as("c")), table("d"))) assertEqual(s"$sqlWithoutGroupBy group by grouping sets((a, b), (a), ())", Aggregate(Seq(GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b))), - Seq('a, 'b, 'sum.function('c).as("c")), table("d"), false)) + Seq('a, 'b, 'sum.function('c).as("c")), table("d"))) val m = intercept[ParseException] { parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index ad98fc757a7c..76ee297dfca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -446,7 +446,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( } else { Alias(CreateStruct(groupingAttributes), "key")() } - val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan, false) + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c8c964e928ef..bf3f7058b9bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -66,15 +66,15 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan, false)) + Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))), - aliasedAgg, df.logicalPlan, false)) + aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))), - aliasedAgg, df.logicalPlan, false)) + aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index d68c5f03b7c4..f4637fe4b3f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -40,6 +40,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(e) || + e.isInstanceOf[GroupingExprRef] || agg.groupingExpressions.exists(_.semanticEquals(e)) } @@ -72,7 +73,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } } // There is no Python UDF over aggregate expression - Project(projList.toSeq, Aggregate(agg.groupingExpressions, aggExpr.toSeq, agg.child)) + Project(projList.toSeq, agg.copy(aggregateExpressions = aggExpr.toSeq)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -119,21 +120,9 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { groupingExpr += expr } } - val aggExpr = agg.aggregateExpressions.map { expr => - expr.transformUp { - // PythonUDF over aggregate was pull out by ExtractPythonUDFFromAggregate. - // PythonUDF here should be either - // 1. Argument of an aggregate function. - // CheckAnalysis guarantees the arguments are deterministic. - // 2. PythonUDF in grouping key. Grouping key must be deterministic. - // 3. PythonUDF not in grouping key. 
It is either no arguments or with grouping key - // in its arguments. Such PythonUDF was pull out by ExtractPythonUDFFromAggregate, too. - case p: PythonUDF if p.udfDeterministic => - val canonicalized = p.canonicalized.asInstanceOf[PythonUDF] - attributeMap.getOrElse(canonicalized, p) - }.asInstanceOf[NamedExpression] - } - Aggregate(groupingExpr.toSeq, aggExpr, Project((projList ++ agg.child.output).toSeq, agg.child)) + agg.copy( + groupingExpressions = groupingExpr.toSeq, + child = Project((projList ++ agg.child.output).toSeq, agg.child)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { From 06224449062e2f1e97a419adf44de45dce9eeb6d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 12 Apr 2021 18:44:39 +0200 Subject: [PATCH 10/15] minor fixes --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 2 +- .../sql/catalyst/plans/logical/basicLogicalOperators.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 77f0c719e426..c3d2f336f06f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -795,7 +795,7 @@ object NullPropagation extends Rule[LogicalPlan] { */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - EnforceGroupingReferencesInAggregates(CleanupAliases(propagateFoldables(plan)._1)) + CleanupAliases(propagateFoldables(plan)._1) } private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d07872537667..a4b77a5cc493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -690,8 +690,7 @@ object Aggregate { groupingExpressions.zipWithIndex .foldLeft(mutable.Map.empty[Expression, (Expression, Int)]) { case (m, (ge, i)) => - if (ge.deterministic && !ge.foldable && ge.children.nonEmpty && - !m.contains(ge.canonicalized)) { + if (!ge.foldable && ge.children.nonEmpty && !m.contains(ge.canonicalized)) { m += ge.canonicalized -> (ge, i) } m From 2e79eb910f5d4ef60227f7352e86dbed2786cb7a Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 13 Apr 2021 19:47:42 +0200 Subject: [PATCH 11/15] review fixes --- .../catalyst/analysis/UpdateNullability.scala | 2 +- .../expressions/aggregate/interfaces.scala | 8 +++++ .../sql/catalyst/optimizer/Optimizer.scala | 10 ++---- .../plans/logical/basicLogicalOperators.scala | 31 ++++++++----------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala index d9a52f3ccf2c..adc696618d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala @@ -58,7 +58,7 @@ object UpdateAttributeNullability extends 
Rule[LogicalPlan] { * referenced grouping expression. */ object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a: Aggregate => val nullabilities = a.groupingExpressions.map(_.nullable).toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 281734c6f14a..8c70c86aa186 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -80,6 +80,14 @@ object AggregateExpression { filter, NamedExpression.newExprId) } + + def containsAggregate(expr: Expression): Boolean = { + expr.find(isAggregate).isDefined + } + + def isAggregate(expr: Expression): Boolean = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f794f2152d04..5e8f0da6373c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -524,23 +524,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { } private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = { - val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate) + val upperHasNoAggregateExpressions = + !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate) lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet( lower .aggregateExpressions .filter(_.deterministic) - .filter(!isAggregate(_)) + .filterNot(AggregateExpression.containsAggregate) .map(_.toAttribute) )) upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg } - - private def isAggregate(expr: Expression): Boolean = { - expr.find(e => e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupedAggPandasUDF(e)).isDefined - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7edc97380310..d04e8cbc8cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -819,14 +819,16 @@ case class Aggregate( object Aggregate { private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = { - groupingExpressions.zipWithIndex - .foldLeft(mutable.Map.empty[Expression, (Expression, Int)]) { - case (m, (ge, i)) => - if (!ge.foldable && ge.children.nonEmpty && !m.contains(ge.canonicalized)) { - m += ge.canonicalized -> (ge, i) - } - m - } + val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)] + var i = 0 + groupingExpressions.foreach { + case ge if !ge.foldable && ge.children.nonEmpty && + !complexGroupingExpressions.contains(ge.canonicalized) => + complexGroupingExpressions += ge.canonicalized -> 
(ge, i) + i += 1 + case _ => + } + complexGroupingExpressions } private def insertGroupingReferences( @@ -834,9 +836,7 @@ object Aggregate { groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = { def insertGroupingExprRefs(e: Expression): Expression = { e match { - case _ if !e.deterministic => e - case _: AggregateExpression => e - case _ if PythonUDF.isGroupedAggPandasUDF(e) => e + case _ if AggregateExpression.isAggregate(e) => e case _ if groupingExpressions.contains(e.canonicalized) => val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized) GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable) @@ -851,19 +851,14 @@ object Aggregate { groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], child: LogicalPlan): Aggregate = { - val dealiasedGroupingExpressions = groupingExpressions.map { - case a: Alias => a.child - case o => o - } - val complexGroupingExpressions = - collectComplexGroupingExpressions(dealiasedGroupingExpressions) + val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions) val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) { insertGroupingReferences(aggregateExpressions, complexGroupingExpressions) } else { aggregateExpressions } - new Aggregate(dealiasedGroupingExpressions, aggrExprWithGroupingReferences, child) + new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child) } } From cff9b9ab289412529728749001ded9550f14e66e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 14 Apr 2021 13:39:26 +0200 Subject: [PATCH 12/15] fix latest test failures, add new test case --- .../plans/logical/basicLogicalOperators.scala | 13 ++++++------- .../test/resources/sql-tests/inputs/group-by.sql | 6 +++++- .../resources/sql-tests/results/group-by.sql.out | 16 +++++++++++++++- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d04e8cbc8cc7..04dc0d0592a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -797,8 +797,7 @@ case class Aggregate( private def expandGroupingReferences(e: Expression): Expression = { e match { - case _: AggregateExpression => e - case _ if PythonUDF.isGroupedAggPandasUDF(e) => e + case _ if AggregateExpression.isAggregate(e) => e case g: GroupingExprRef => groupingExpressions(g.ordinal) case _ => e.mapChildren(expandGroupingReferences) } @@ -821,12 +820,12 @@ object Aggregate { private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = { val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)] var i = 0 - groupingExpressions.foreach { - case ge if !ge.foldable && ge.children.nonEmpty && - !complexGroupingExpressions.contains(ge.canonicalized) => + groupingExpressions.foreach { ge => + if (!ge.foldable && ge.children.nonEmpty && + !complexGroupingExpressions.contains(ge.canonicalized)) { complexGroupingExpressions += ge.canonicalized -> (ge, i) - i += 1 - case _ => + } + i += 1 } complexGroupingExpressions } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 
8f03e08f3cbc..c18a92a38355 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -183,4 +183,8 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 -- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function SELECT not(a IS NULL), count(*) AS c FROM testData -GROUP BY a IS NULL +GROUP BY a IS NULL; + +SELECT a + b + rand(0), count(*) AS c +FROM testData +GROUP BY a + b; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index a4b2070fc4ca..01bd876faf19 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 63 +-- Number of queries: 64 -- !query @@ -653,3 +653,17 @@ struct<(NOT (a IS NULL)):boolean,c:bigint> -- !query output false 2 true 7 + + +-- !query +SELECT a + b + rand(0), count(*) AS c +FROM testData +GROUP BY a + b +-- !query schema +struct<((a + b) + rand(0)):double,c:bigint> +-- !query output +2.7604953758285915 1 +3.6363787615254752 2 +4.523419425688557 2 +5.095347282642472 1 +NULL 3 From 78296a86585e9f0cad5c27fa534aa614a710793e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 14 Apr 2021 17:47:16 +0200 Subject: [PATCH 13/15] better non-deterministic test case --- .../test/resources/sql-tests/inputs/group-by.sql | 4 ++-- .../resources/sql-tests/results/group-by.sql.out | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c18a92a38355..988ad99418a1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -185,6 +185,6 @@ SELECT not(a IS NULL), count(*) AS c FROM testData GROUP BY a IS NULL; -SELECT a + b + rand(0), count(*) AS c +SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c FROM testData -GROUP BY a + b; +GROUP BY a IS NULL; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 01bd876faf19..b5471a785a22 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -656,14 +656,11 @@ true 7 -- !query -SELECT a + b + rand(0), count(*) AS c +SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c FROM testData -GROUP BY a + b +GROUP BY a IS NULL -- !query schema -struct<((a + b) + rand(0)):double,c:bigint> +struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint> -- !query output -2.7604953758285915 1 -3.6363787615254752 2 -4.523419425688557 2 -5.095347282642472 1 -NULL 3 +0.7604953758285915 7 +1.0 2 From 72c173beebaf5ec63a9fd9360410b355462943d1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 15 Apr 2021 13:36:03 +0200 Subject: [PATCH 14/15] make new rules non excludable --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5e8f0da6373c..2f1020535ff4 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -268,7 +268,9 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: - ReplaceUpdateFieldsExpression.ruleName :: Nil + ReplaceUpdateFieldsExpression.ruleName :: + EnforceGroupingReferencesInAggregates.ruleName :: + UpdateGroupingExprRefNullability.ruleName :: Nil /** * Optimize all the subqueries inside expression. From fb3a19dad5ac4448b8fe9d1b12ed1c3f6a0369a7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 17 Apr 2021 18:23:21 +0200 Subject: [PATCH 15/15] fix validConstraints, minor changes --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 8 +++----- .../catalyst/plans/logical/basicLogicalOperators.scala | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9916afe1ac43..a96674fe9705 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -297,11 +297,9 @@ object PhysicalAggregation { val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. - case agg: AggregateExpression - if !equivalentAggregateExpressions.addExpr(agg) => agg - case udf: PythonUDF - if PythonUDF.isGroupedAggPandasUDF(udf) && - !equivalentAggregateExpressions.addExpr(udf) => udf + case a + if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) => + a } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ef6d1d8daf1..21e87b4c6260 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -837,7 +837,8 @@ case class Aggregate( } override lazy val validConstraints: ExpressionSet = { - val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) + val nonAgg = aggregateExpressionsWithoutGroupingRefs. + filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) getAllValidConstraints(nonAgg) }
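
A minimal, standalone Scala sketch of the mechanism this series introduces (illustrative only, not Spark code: the names Expr, GroupingRef, simplify and withRefs are invented here). It models why replacing occurrences of a complex grouping expression in the aggregate list with an ordinal reference keeps a rewrite such as Not(IsNull(e)) => IsNotNull(e) from producing an aggregate expression that no longer matches any grouping expression:

  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class IsNull(child: Expr) extends Expr
  final case class IsNotNull(child: Expr) extends Expr
  final case class Not(child: Expr) extends Expr
  // Opaque ordinal reference into the grouping list; rewrites cannot see
  // through it, mirroring GroupingExprRef.
  final case class GroupingRef(ordinal: Int) extends Expr

  // Toy analogue of BooleanSimplification: Not(IsNull(e)) => IsNotNull(e).
  def simplify(e: Expr): Expr = e match {
    case Not(IsNull(c)) => IsNotNull(simplify(c))
    case Not(c)         => Not(simplify(c))
    case IsNull(c)      => IsNull(simplify(c))
    case IsNotNull(c)   => IsNotNull(simplify(c))
    case leaf           => leaf
  }

  // Toy analogue of Aggregate.withGroupingRefs: replace any aggregate-list
  // subtree that equals a grouping expression with its ordinal.
  def withRefs(e: Expr, grouping: Seq[Expr]): Expr = grouping.indexOf(e) match {
    case -1 => e match {
      case Not(c)       => Not(withRefs(c, grouping))
      case IsNull(c)    => IsNull(withRefs(c, grouping))
      case IsNotNull(c) => IsNotNull(withRefs(c, grouping))
      case leaf         => leaf
    }
    case i  => GroupingRef(i)
  }

  val grouping = Seq[Expr](IsNull(Col("id")))   // GROUP BY id IS NULL
  val aggExpr  = Not(IsNull(Col("id")))         // SELECT not(id IS NULL)

  // Without references the rewrite loses the grouping sub-expression:
  assert(simplify(aggExpr) == IsNotNull(Col("id")))
  // With references it survives intact, still tied to grouping ordinal 0:
  assert(simplify(withRefs(aggExpr, grouping)) == Not(GroupingRef(0)))

In the actual patches the reference additionally carries the grouping expression's dataType and nullability (kept in sync by UpdateGroupingExprRefNullability), and consumers that need the full expression, such as getAliasMap, PhysicalAggregation and the validConstraints fix above, expand the references back via aggregateExpressionsWithoutGroupingRefs.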
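
A similarly standalone sketch (again with invented names) of the ordinal bookkeeping that RemoveLiteralFromGroupExpressions and RemoveRepetitionFromGroupExpressions perform above: when grouping expressions are dropped, each surviving reference's ordinal is shifted down by the number of entries dropped before it, computed with the same scanLeft pattern as the patch:

  // foldable(i) == true stands for "grouping expression i is a literal
  // that the rule drops".
  val foldable = Seq(true, false, true, false)
  // droppedBefore(i) = number of dropped grouping expressions with index < i.
  val droppedBefore = foldable.scanLeft(0)((n, f) => n + (if (f) 1 else 0))
  // droppedBefore == Vector(0, 1, 1, 2, 2)
  def remap(ordinal: Int): Int = ordinal - droppedBefore(ordinal)
  assert(remap(1) == 0) // expression 1 becomes the new ordinal 0
  assert(remap(3) == 1) // expression 3 becomes the new ordinal 1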