From ec50ab6987be596490912944644a3527db1b28b6 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Fri, 15 May 2020 15:36:28 +0000
Subject: [PATCH] [SPARK-31620][SQL] Fix reference binding failure in case of
 a final agg that contains a subquery

### What changes were proposed in this pull request?

Instead of using `child.output` directly, we should use `inputAggBufferAttributes` from the current agg expression for `Final` and `PartialMerge` aggregates to bind references for their `mergeExpressions`.

### Why are the changes needed?

When planning aggregates, the partial aggregate uses the agg funcs' `inputAggBufferAttributes` as its output, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala#L105

For the final `HashAggregateExec`, we need to bind `DeclarativeAggregate.mergeExpressions` to the output of the partial aggregate operator, see https://github.com/apache/spark/blob/v3.0.0-rc1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L348

This is usually fine. However, if we copy the agg func somehow after agg planning, e.g. in `PlanSubqueries`, the `DeclarativeAggregate` will be replaced by a new instance with new `inputAggBufferAttributes` and `mergeExpressions`. Then we can't bind the `mergeExpressions` to the output of the partial aggregate operator, as that output uses the `inputAggBufferAttributes` of the original `DeclarativeAggregate` from before the copy.

Note that `ImperativeAggregate` doesn't have this problem, as we don't need to bind its `mergeExpressions`. It has a different mechanism to access buffer values, via `mutableAggBufferOffset` and `inputAggBufferOffset`.

### Does this PR introduce _any_ user-facing change?

Yes. Users hit an error previously, but the same query runs successfully after this change.

### How was this patch tested?

Added a regression test.

Closes #28496 from Ngone51/spark-31620.

Authored-by: yi.wu
Signed-off-by: Wenchen Fan
---
 .../aggregate/BaseAggregateExec.scala         | 56 +++++++++++++++++++
 .../aggregate/HashAggregateExec.scala         |  8 +--
 .../aggregate/ObjectHashAggregateExec.scala   |  4 +-
 .../aggregate/SortAggregateExec.scala         |  6 +-
 .../spark/sql/DataFrameAggregateSuite.scala   | 39 +++++++++++++
 5 files changed, 104 insertions(+), 9 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
new file mode 100644
index 0000000000000..3ee35777933ac
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * Holds common logic for aggregate operators.
+ */
+trait BaseAggregateExec extends UnaryExecNode {
+  def groupingExpressions: Seq[NamedExpression]
+  def aggregateExpressions: Seq[AggregateExpression]
+  def aggregateAttributes: Seq[Attribute]
+  def resultExpressions: Seq[NamedExpression]
+
+  protected def inputAttributes: Seq[Attribute] = {
+    val modes = aggregateExpressions.map(_.mode).distinct
+    if (modes.contains(Final) || modes.contains(PartialMerge)) {
+      // SPARK-31620: when planning aggregates, the partial aggregate uses the aggregate function's
+      // `inputAggBufferAttributes` as its output, and Final and PartialMerge aggregates rely on
+      // that output to bind references for `DeclarativeAggregate.mergeExpressions`. But if we copy
+      // the aggregate function somehow after aggregate planning, e.g. in `PlanSubqueries`, the
+      // `DeclarativeAggregate` will be replaced by a new instance with new
+      // `inputAggBufferAttributes` and `mergeExpressions`. Then Final and PartialMerge aggregates
+      // can't bind the `mergeExpressions` to the output of the partial aggregate, as they use the
+      // `inputAggBufferAttributes` of the original `DeclarativeAggregate` before the copy. Instead,
+      // we shall use `inputAggBufferAttributes` after the copy to match the new `mergeExpressions`.
+      val aggAttrs = aggregateExpressions
+        // There are exactly four cases that need `inputAggBufferAttributes` from the child according
+        // to the agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
+        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
+        .flatMap(_.inputAggBufferAttributes)
+      child.output.dropRight(aggAttrs.length) ++ aggAttrs
+    } else {
+      child.output
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 19a47ffc6dd03..617d69bfa75ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -47,7 +47,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends BaseAggregateExec with BlockingOperatorWithCodegen {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -121,7 +121,7 @@ case class HashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-            child.output,
+            inputAttributes,
             iter,
             testFallbackStartsAt,
             numOutputRows,
@@ -254,7 +254,7 @@ case class HashAggregateExec(
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
     val updateExpr = aggregateExpressions.flatMap { e =>
       e.mode match {
         case Partial | Complete =>
@@ -817,7 +817,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
+    val inputAttr = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 5b340eead39e6..10b9f17f6d82e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -65,7 +65,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends BaseAggregateExec {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -121,7 +121,7 @@ case class ObjectHashAggregateExec(
             resultExpressions,
             (expressions, inputSchema) =>
               newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
-            child.output,
+            inputAttributes,
             iter,
             fallbackCountThreshold,
             numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 7ab6ecc08a7bc..be4bdc355ad6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends BaseAggregateExec {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -86,7 +86,7 @@ case class SortAggregateExec(
         val outputIter = new SortBasedAggregationIterator(
           partIndex,
           groupingExpressions,
-          child.output,
+          inputAttributes,
           iter,
           aggregateExpressions,
           aggregateAttributes,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 73259a0ed3b50..418664df058d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -772,4 +772,43 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2)
     )
   }
+
+  Seq(true, false).foreach { value =>
+    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+        withTempView("t1", "t2") {
+          sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+          sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
+          // test without grouping keys
+          checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+            Row(4) :: Nil)
+
+          // test with grouping keys
+          checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+            "t2 group by c"), Row(3, 4) :: Nil)
+
+          // test with distinct
+          checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+          // test subquery with agg
+          checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+            " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+          // test SortAggregateExec
+          var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: SortAggregateExec => true }.isDefined)
+          checkAnswer(df, Row("str1") :: Nil)
+
+          // test ObjectHashAggregateExec
+          df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+          assert(df.queryExecution.executedPlan
+            .find { case _: ObjectHashAggregateExec => true }.isDefined)
+          checkAnswer(df, Row(Array(4), 4) :: Nil)
+        }
+      }
+    }
+  }
 }