From 4c40c9b4eb37f88ef740aa5d21f1c3c68e5b4ab7 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 22 Apr 2022 19:51:32 +0800 Subject: [PATCH 1/9] [SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions --- .../expressions/aggregate/Aggregation.java | 10 ++--- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/AggregatePushDownUtils.scala | 20 +++++++--- .../datasources/DataSourceStrategy.scala | 7 ++-- .../execution/datasources/orc/OrcUtils.scala | 2 +- .../datasources/parquet/ParquetUtils.scala | 2 +- .../v2/V2ScanRelationPushDown.scala | 19 ++++++++-- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 23 +++++------ .../datasources/v2/orc/OrcScan.scala | 2 +- .../datasources/v2/parquet/ParquetScan.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 38 +++++++++++++++++-- 11 files changed, 88 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java index cf7dbb2978dd7..11d9e475ca1bf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * Aggregation in SQL statement. @@ -30,14 +30,14 @@ @Evolving public final class Aggregation implements Serializable { private final AggregateFunc[] aggregateExpressions; - private final NamedReference[] groupByColumns; + private final Expression[] groupByExpressions; - public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) { + public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) { this.aggregateExpressions = aggregateExpressions; - this.groupByColumns = groupByColumns; + this.groupByExpressions = groupByExpressions; } public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; } - public NamedReference[] groupByColumns() { return groupByColumns; } + public Expression[] groupByExpressions() { return groupByExpressions; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 5cf8aa91ea5cd..edb72c35b5117 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -163,7 +163,7 @@ case class RowDataSourceScanExec( "PushedFilters" -> pushedFilters) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), - "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ + "PushedGroupByColumns" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 
4779a3eaf2531..08f4395228129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils @@ -93,8 +94,8 @@ object AggregatePushDownUtils { return None } - if (aggregation.groupByColumns.nonEmpty && - partitionNames.size != aggregation.groupByColumns.length) { + if (aggregation.groupByExpressions.nonEmpty && + partitionNames.size != aggregation.groupByExpressions.length) { // If there are group by columns, we only push down if the group by columns are the same as // the partition columns. In theory, if group by columns are a subset of partition columns, // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3, @@ -106,7 +107,9 @@ object AggregatePushDownUtils { // aggregate push down simple and don't handle this complicate case for now. return None } - aggregation.groupByColumns.foreach { col => + aggregation.groupByExpressions.foreach { expr => + assert(expr.isInstanceOf[FieldReference]) + val col = expr.asInstanceOf[FieldReference] // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None @@ -137,7 +140,8 @@ object AggregatePushDownUtils { def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { a.aggregateExpressions.sortBy(_.hashCode()) .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && - a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + a.groupByExpressions.sortBy(_.hashCode()) + .sameElements(b.groupByExpressions.sortBy(_.hashCode())) } /** @@ -164,7 +168,7 @@ object AggregatePushDownUtils { def getSchemaWithoutGroupingExpression( aggSchema: StructType, aggregation: Aggregation): StructType = { - val numOfGroupByColumns = aggregation.groupByColumns.length + val numOfGroupByColumns = aggregation.groupByExpressions.length if (numOfGroupByColumns > 0) { new StructType(aggSchema.fields.drop(numOfGroupByColumns)) } else { @@ -179,7 +183,11 @@ object AggregatePushDownUtils { partitionSchema: StructType, aggregation: Aggregation, partitionValues: InternalRow): InternalRow = { - val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head) + val groupByColNames = + aggregation.groupByExpressions.map { expr => + assert(expr.isInstanceOf[FieldReference]) + expr.asInstanceOf[FieldReference].fieldNames.head + } assert(groupByColNames.length == partitionSchema.length && groupByColNames.length == partitionValues.numFields, "The number of group by columns " + s"${groupByColNames.length} should be the same as partition schema length " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 
1b14884e75994..e35d09320760c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -759,14 +759,13 @@ object DataSourceStrategy protected[sql] def translateAggregation( aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference.column(name).asInstanceOf[FieldReference]) + def translateGroupBy(e: Expression): Option[V2Expression] = e match { + case PushableExpression(expr) => Some(expr) case _ => None } val translatedAggregates = aggregates.flatMap(translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) + val translatedGroupBys = groupBy.flatMap(translateGroupBy) if (translatedAggregates.length != aggregates.length || translatedGroupBys.length != groupBy.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 79abdfe46909d..9d98122905e7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -516,7 +516,7 @@ object OrcUtils extends Logging { val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy, (0 until schemaWithoutGroupBy.length).toArray) val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues) - if (aggregation.groupByColumns.nonEmpty) { + if (aggregation.groupByExpressions.nonEmpty) { val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol( partitionSchema, aggregation, partitionValues) new JoinedRow(reOrderedPartitionValues, resultRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 9f2e6580ecb46..afd7006bec0ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -277,7 +277,7 @@ object ParquetUtils { throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) } - if (aggregation.groupByColumns.nonEmpty) { + if (aggregation.groupByExpressions.nonEmpty) { val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol( partitionSchema, aggregation, partitionValues) new JoinedRow(reorderedPartitionValues, converter.currentRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 6455e25089276..211969526d54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -184,9 +184,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit // scalastyle:on val newOutput = scan.readSchema().toAttributes assert(newOutput.length == groupingExpressions.length + finalAggregates.length) + val groupByExprToOutputOrdinal = 
mutable.HashMap.empty[Expression, Int] + var ordinal = 0 val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case (_, b) => b + case (expr, b) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + ordinal += 1 + } + b } val aggOutput = newOutput.drop(groupAttrs.length) val output = groupAttrs ++ aggOutput @@ -197,7 +204,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit |Pushed Aggregate Functions: | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} |Pushed Group by: - | ${pushedAggregates.get.groupByColumns.mkString(", ")} + | ${pushedAggregates.get.groupByExpressions.mkString(", ")} |Output: ${output.mkString(", ")} """.stripMargin) @@ -206,14 +213,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) if (r.supportCompletePushDown(pushedAggregates.get)) { val projectExpressions = finalResultExpressions.map { expr => - // TODO At present, only push down group by attribute is supported. - // In future, more attribute conversion is extended here. e.g. GetStructField expr.transform { case agg: AggregateExpression => val ordinal = aggExprToOutputOrdinal(agg.canonicalized) val child = addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(ordinal), expr.dataType) } }.asInstanceOf[Seq[NamedExpression]] Project(projectExpressions, scanRelation) @@ -256,6 +264,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit case other => other } agg.copy(aggregateFunction = aggFunction) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(ordinal), expr.dataType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 0a1542a42956d..630afc838a779 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,7 +20,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} @@ -73,9 +73,12 @@ case class JDBCScanBuilder( private var pushedGroupByCols: Option[Array[String]] = None override def supportCompletePushDown(aggregation: Aggregation): Boolean = { - lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames() + lazy val fieldNames = 
aggregation.groupByExpressions()(0) match { + case field: FieldReference => field.fieldNames + case _ => Array.empty[String] + } jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) || - (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 && + (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 && jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_))) } @@ -86,20 +89,18 @@ case class JDBCScanBuilder( val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate) if (compiledAggs.length != aggregation.aggregateExpressions.length) return false - val groupByCols = aggregation.groupByColumns.map { col => - if (col.fieldNames.length != 1) return false - dialect.quoteIdentifier(col.fieldNames.head) - } + val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression) + if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false // The column names here are already quoted and can be used to build sql string directly. // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") => // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" // GROUP BY "DEPT", "NAME" - val selectList = groupByCols ++ compiledAggs - val groupByClause = if (groupByCols.isEmpty) { + val selectList = compiledGroupBys ++ compiledAggs + val groupByClause = if (compiledGroupBys.isEmpty) { "" } else { - "GROUP BY " + groupByCols.mkString(",") + "GROUP BY " + compiledGroupBys.mkString(",") } val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " + @@ -107,7 +108,7 @@ case class JDBCScanBuilder( try { finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect) pushedAggregateList = selectList - pushedGroupByCols = Some(groupByCols) + pushedGroupByCols = Some(compiledGroupBys) true } catch { case NonFatal(e) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index ad8857d98037c..ccb9ca9c6b3f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -87,7 +87,7 @@ case class OrcScan( lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByColumns)) + seqToString(pushedAggregate.get.groupByExpressions)) } else { ("[]", "[]") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 6b35f2406a82f..99632d79cd8dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -116,7 +116,7 @@ case class ParquetScan( lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByColumns)) + seqToString(pushedAggregate.get.groupByExpressions)) } else { ("[]", "[]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 
5cfa2f465a2be..237da42b4820f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -736,11 +736,41 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: SUM with group by") { - val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " + + val df1 = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [], PushedGroupByColumns: [DEPT], ") - checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) + checkAnswer(df1, Seq(Row(19000), Row(22000), Row(12000))) + + val df2 = sql( + """ + |SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as key, + | SUM(SALARY) FROM h2.test.employee GROUP BY key""".stripMargin) + checkAggregateRemoved(df2) + checkPushedInfo(df2, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByColumns: [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df2, Seq(Row(0, 44000), Row(9000, 9000))) + + val df3 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key")) + .agg(sum($"SALARY")) + checkAggregateRemoved(df3, false) + checkPushedInfo(df3, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByColumns: [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000))) } test("scan with aggregate push-down: DISTINCT SUM with group by") { From 2440505cdf7d8112555da6a1209eff78c23dfc1e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 23 Apr 2022 08:50:49 +0800 Subject: [PATCH 2/9] Update code --- .../test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 237da42b4820f..a41040f1561d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -751,7 +751,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel """ |PushedAggregates: [SUM(SALARY)], |PushedFilters: [], - |PushedGroupByColumns: [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |PushedGroupByColumns: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], |""".stripMargin.replaceAll("\n", " ")) checkAnswer(df2, Seq(Row(0, 44000), Row(9000, 9000))) @@ -768,7 +769,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel """ |PushedAggregates: [SUM(SALARY)], |PushedFilters: [], - |PushedGroupByColumns: [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |PushedGroupByColumns: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], |""".stripMargin.replaceAll("\n", " ")) 
checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000))) } From b7ebfc52e56965e00265e55d30f0199513c91aa5 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 26 Apr 2022 10:14:45 +0800 Subject: [PATCH 3/9] Update code --- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/AggregatePushDownUtils.scala | 2 +- .../v2/V2ScanRelationPushDown.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 50 +++++++++---------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index edb72c35b5117..292248485253e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -163,7 +163,7 @@ case class RowDataSourceScanExec( "PushedFilters" -> pushedFilters) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), - "PushedGroupByColumns" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ + "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 08f4395228129..6d7972e9af132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -108,7 +108,7 @@ object AggregatePushDownUtils { return None } aggregation.groupByExpressions.foreach { expr => - assert(expr.isInstanceOf[FieldReference]) + if (!expr.isInstanceOf[FieldReference]) return None val col = expr.asInstanceOf[FieldReference] // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 211969526d54b..63f81eeaaa473 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -213,7 +213,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) if (r.supportCompletePushDown(pushedAggregates.get)) { val projectExpressions = finalResultExpressions.map { expr => - expr.transform { + expr.transformDown { case agg: AggregateExpression => val ordinal = aggExprToOutputOrdinal(agg.canonicalized) val child = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a41040f1561d2..d66bc95349d51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ 
-187,7 +187,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df4, false) checkPushedInfo(df4, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -279,7 +279,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkSortRemoved(df6, false) checkLimitRemoved(df6, false) checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," + - " PushedFilters: [], PushedGroupByColumns: [DEPT], ") + " PushedFilters: [], PushedGroupByExpressions: [DEPT], ") checkAnswer(df6, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -633,7 +633,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], ") + "PushedGroupByExpressions: [DEPT], ") checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } @@ -654,7 +654,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [ID IS NOT NULL, ID > 0], " + - "PushedGroupByColumns: [], ") + "PushedGroupByExpressions: [], ") checkAnswer(df, Seq(Row(2, 1.5))) } @@ -739,7 +739,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df1 = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") checkAggregateRemoved(df1) checkPushedInfo(df1, "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT], ") + "PushedFilters: [], PushedGroupByExpressions: [DEPT], ") checkAnswer(df1, Seq(Row(19000), Row(22000), Row(12000))) val df2 = sql( @@ -751,7 +751,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel """ |PushedAggregates: [SUM(SALARY)], |PushedFilters: [], - |PushedGroupByColumns: + |PushedGroupByExpressions: |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], |""".stripMargin.replaceAll("\n", " ")) checkAnswer(df2, Seq(Row(0, 44000), Row(9000, 9000))) @@ -769,7 +769,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel """ |PushedAggregates: [SUM(SALARY)], |PushedFilters: [], - |PushedGroupByColumns: + |PushedGroupByExpressions: |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], |""".stripMargin.replaceAll("\n", " ")) checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000))) @@ -779,7 +779,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -789,7 +789,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], 
PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000), Row(12000, 1200))) } @@ -803,7 +803,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(filters1.isEmpty) checkAggregateRemoved(df1) checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), Row("2#david", 10000), Row("6#jen", 12000))) @@ -815,7 +815,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(filters2.isEmpty) checkAggregateRemoved(df2) checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -835,7 +835,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df, false) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -845,7 +845,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .min("SALARY").as("total") checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -860,7 +860,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(query, false)// filter over aggregate not pushed down checkAggregateRemoved(query) checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -892,7 +892,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -902,7 +902,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 
70.71067811865476d), Row(0d, null))) } @@ -912,7 +912,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -922,7 +922,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } @@ -934,7 +934,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df2.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: []" checkKeywordsExistsInExplain(df2, expectedPlanFragment) relation.scan match { case v1: V1ScanWrapper => @@ -987,7 +987,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], ") + "PushedGroupByExpressions: [DEPT], ") checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 0d, 0d, 3, 0d), Row(2, 2, 2, 2, 2, 10000d, 9000d, 10000d, 10000d, 9000d, 0d, 2, 0d))) @@ -1001,7 +1001,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expectedPlanFragment = if (ansiMode) { "PushedAggregates: [SUM(2147483647 + DEPT)], " + "PushedFilters: [], " + - "PushedGroupByColumns: []" + "PushedGroupByExpressions: []" } else { "PushedFilters: []" } @@ -1150,7 +1150,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df) checkPushedInfo(df, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) val df2 = spark.table("h2.test.employee") @@ -1160,7 +1160,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df2) checkPushedInfo(df2, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) } @@ -1177,7 +1177,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df, false) checkPushedInfo(df, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]") 
checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) @@ -1193,7 +1193,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df2, false) checkPushedInfo(df2, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]") checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) } From e53066e0a9be85de6388bd760d5dbf3a02de3147 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 26 Apr 2022 14:40:41 +0800 Subject: [PATCH 4/9] Update code --- .../datasources/AggregatePushDownUtils.scala | 21 ++++++------ .../v2/V2ScanRelationPushDown.scala | 13 ++++--- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 34 +++++++++++++++++++ 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 6d7972e9af132..11d6a2b426272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils @@ -107,13 +107,11 @@ object AggregatePushDownUtils { // aggregate push down simple and don't handle this complicate case for now. 
return None } - aggregation.groupByExpressions.foreach { expr => - if (!expr.isInstanceOf[FieldReference]) return None - val col = expr.asInstanceOf[FieldReference] + aggregation.groupByExpressions.flatMap(extractColName).foreach { fieldName => // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) - if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None - finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head)) + if (!isPartitionCol(fieldName)) return None + finalSchema = finalSchema.add(getStructFieldForCol(fieldName)) } aggregation.aggregateExpressions.foreach { @@ -183,11 +181,7 @@ object AggregatePushDownUtils { partitionSchema: StructType, aggregation: Aggregation, partitionValues: InternalRow): InternalRow = { - val groupByColNames = - aggregation.groupByExpressions.map { expr => - assert(expr.isInstanceOf[FieldReference]) - expr.asInstanceOf[FieldReference].fieldNames.head - } + val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName) assert(groupByColNames.length == partitionSchema.length && groupByColNames.length == partitionValues.numFields, "The number of group by columns " + s"${groupByColNames.length} should be the same as partition schema length " + @@ -205,4 +199,9 @@ object AggregatePushDownUtils { partitionValues } } + + private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match { + case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head) + case _ => None + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 63f81eeaaa473..51bb6f2b02b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -186,12 +186,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit assert(newOutput.length == groupingExpressions.length + finalAggregates.length) val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] var ordinal = 0 - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { - case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case (expr, b) => + val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { + case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) + case ((expr, b), idx) => if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { - groupByExprToOutputOrdinal(expr.canonicalized) = ordinal - ordinal += 1 + groupByExprToOutputOrdinal(expr.canonicalized) = idx } b } @@ -220,8 +219,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) + val idx = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(idx), expr.dataType) } }.asInstanceOf[Seq[NamedExpression]] Project(projectExpressions, scanRelation) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index d66bc95349d51..74e226acb7a14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -773,6 +773,40 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], |""".stripMargin.replaceAll("\n", " ")) checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000))) + + val df4 = sql( + """ + |SELECT DEPT, CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as key, + | SUM(SALARY) FROM h2.test.employee GROUP BY DEPT, key""".stripMargin) + checkAggregateRemoved(df4) + checkPushedInfo(df4, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df4, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 22000), Row(6, 0, 12000))) + + val df5 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"DEPT", + when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0) + .as("key")) + .agg(sum($"SALARY")) + checkAggregateRemoved(df5, false) + checkPushedInfo(df5, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df5, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 22000), Row(6, 0, 12000))) } test("scan with aggregate push-down: DISTINCT SUM with group by") { From baa96df66c6a798b6e88e138c8a4f867b04d5771 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 26 Apr 2022 14:50:08 +0800 Subject: [PATCH 5/9] Update code --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 51bb6f2b02b58..ed901d72f47f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -185,7 +185,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit val newOutput = scan.readSchema().toAttributes assert(newOutput.length == groupingExpressions.length + finalAggregates.length) val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - var ordinal = 0 val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) case ((expr, b), idx) => From 410158b4e6534bb3095d9804e3fcae8af0863fb5 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 26 Apr 2022 14:51:24 +0800 Subject: [PATCH 6/9] Update code --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ed901d72f47f8..e90e5b307d1f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -187,11 +187,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) - case ((expr, b), idx) => + case ((expr, attr), idx) => if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { groupByExprToOutputOrdinal(expr.canonicalized) = idx } - b + attr } val aggOutput = newOutput.drop(groupAttrs.length) val output = groupAttrs ++ aggOutput From 7b48b5436a6f19084f6406d97db680b2b7f7fb18 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 26 Apr 2022 14:52:55 +0800 Subject: [PATCH 7/9] Update code --- .../execution/datasources/v2/V2ScanRelationPushDown.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index e90e5b307d1f0..1ff2dfb585044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -187,9 +187,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) - case ((expr, attr), idx) => + case ((expr, attr), ordinal) => if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { - groupByExprToOutputOrdinal(expr.canonicalized) = idx + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal } attr } @@ -218,8 +218,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val idx = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(idx), expr.dataType) + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(ordinal), expr.dataType) } }.asInstanceOf[Seq[NamedExpression]] Project(projectExpressions, scanRelation) From 0b259cfde2e12a1009fe972c0c1195ac79098a2f Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 27 Apr 2022 08:44:37 +0800 Subject: [PATCH 8/9] Update code --- .../sql/execution/datasources/AggregatePushDownUtils.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 11d6a2b426272..97ee3cd661b3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -107,11 +107,11 @@ object AggregatePushDownUtils { // aggregate push down simple and don't handle this complicate case for now. return None } - aggregation.groupByExpressions.flatMap(extractColName).foreach { fieldName => + aggregation.groupByExpressions.map(extractColName).foreach { colName => // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) - if (!isPartitionCol(fieldName)) return None - finalSchema = finalSchema.add(getStructFieldForCol(fieldName)) + if (colName.isEmpty || !isPartitionCol(colName.get)) return None + finalSchema = finalSchema.add(getStructFieldForCol(colName.get)) } aggregation.aggregateExpressions.foreach { From b92a231afbca65f23f958578567387e7c67154f4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 27 Apr 2022 11:16:10 +0800 Subject: [PATCH 9/9] Update code --- .../sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 630afc838a779..8b378d2d87c49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -70,7 +70,7 @@ case class JDBCScanBuilder( private var pushedAggregateList: Array[String] = Array() - private var pushedGroupByCols: Option[Array[String]] = None + private var pushedGroupBys: Option[Array[String]] = None override def supportCompletePushDown(aggregation: Aggregation): Boolean = { lazy val fieldNames = aggregation.groupByExpressions()(0) match { @@ -108,7 +108,7 @@ case class JDBCScanBuilder( try { finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect) pushedAggregateList = selectList - pushedGroupByCols = Some(compiledGroupBys) + pushedGroupBys = Some(compiledGroupBys) true } catch { case NonFatal(e) => @@ -174,6 +174,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) + pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders) } }
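
Note for connector authors (not part of the patches above): after this series, Aggregation.groupByExpressions() returns general V2 Expressions rather than NamedReferences, so a data source that can only group by plain columns has to detect non-column group-by entries (e.g. CASE WHEN) and decline the push-down. The sketch below is a hypothetical helper, assuming the Spark 3.3+ DS V2 connector API; the names GroupByCompat and simpleGroupByColumns are illustrative and mirror the extractColName logic this series adds to AggregatePushDownUtils.

    // Hypothetical helper, not part of this patch: extract plain column names from the
    // new expression-based group-by list, or signal that the push-down must be refused.
    import org.apache.spark.sql.connector.expressions.FieldReference
    import org.apache.spark.sql.connector.expressions.aggregate.Aggregation

    object GroupByCompat {
      // Returns the top-level column names if every group-by entry is a simple
      // FieldReference; otherwise None, meaning the source should return false from
      // pushAggregation and let Spark keep the aggregate on its side.
      def simpleGroupByColumns(aggregation: Aggregation): Option[Seq[String]] = {
        val cols = aggregation.groupByExpressions.collect {
          case f: FieldReference if f.fieldNames.length == 1 => f.fieldNames.head
        }
        if (cols.length == aggregation.groupByExpressions.length) Some(cols.toSeq) else None
      }
    }

A source that can compile arbitrary expressions, as JDBCScanBuilder does here via dialect.compileExpression, does not need this restriction and can place the compiled expressions directly in its GROUP BY clause, as shown in patch 1/9.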