From 482976740e6597dce91cff7bb7b6b1eb37f14fc9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Dec 2021 23:02:18 +0800 Subject: [PATCH 1/5] Add a class to represent general aggregate functions in DS V2 --- .../sql/connector/expressions/Expression.java | 2 +- .../expressions/aggregate/Count.java | 3 - .../expressions/aggregate/CountStar.java | 3 - .../aggregate/GeneralAggregateFunc.java | 66 +++++++++++++++++++ .../connector/expressions/aggregate/Max.java | 3 - .../connector/expressions/aggregate/Min.java | 3 - .../connector/expressions/aggregate/Sum.java | 3 - .../connector/expressions/filter/Filter.java | 3 - .../read/SupportsPushDownAggregates.java | 21 +++--- .../connector/expressions/expressions.scala | 20 ++---- .../expressions/TransformExtractorSuite.scala | 8 +-- .../sql/execution/DataSourceScanExec.scala | 4 +- .../datasources/DataSourceStrategy.scala | 6 +- .../v2/V2ScanRelationPushDown.scala | 16 +++-- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 2 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 8 ++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 16 ++--- 17 files changed, 120 insertions(+), 67 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java index 6540c91597582..9f6c0975ae0e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java @@ -29,5 +29,5 @@ public interface Expression { /** * Format the expression as a human readable SQL-like string. 
*/ - String describe(); + default String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 1273886e297bf..1685770604a46 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -46,7 +46,4 @@ public String toString() { return "COUNT(" + column.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java index f566ad164b8ef..13801194b63cb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java @@ -32,7 +32,4 @@ public CountStar() { @Override public String toString() { return "COUNT(*)"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java new file mode 100644 index 0000000000000..e0d95cfaafbb0 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * The general implementation of {@link AggregateFunc}, which contains the upper-cased function + * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial + * aggregate with this function to the source, but can only push down the entire aggregate. + *

+ * The currently supported SQL aggregate functions:
+ * <ol>
+ *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
+ * </ol>
+ * + * @since 3.3.0 + */ +@Evolving +public final class GeneralAggregateFunc implements AggregateFunc { + private final String name; + private final boolean isDistinct; + private final NamedReference[] inputs; + + public String name() { return name; } + public boolean isDistinct() { return isDistinct; } + public NamedReference[] inputs() { return inputs; } + + public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) { + this.name = name; + this.isDistinct = isDistinct; + this.inputs = inputs; + } + + @Override + public String toString() { + String inputsString = Arrays.stream(inputs) + .map(Expression::describe) + .collect(Collectors.joining(", ")); + if (isDistinct) { + return name + "(DISTINCT " + inputsString + ")"; + } else { + return name + "(" + inputsString + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index ed07cc9e32187..5acdf14bf7e2f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -35,7 +35,4 @@ public final class Max implements AggregateFunc { @Override public String toString() { return "MAX(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 2e761037746fb..824c607ea7df0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -35,7 +35,4 @@ public final class Min implements AggregateFunc { @Override public String toString() { return "MIN(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 057ebd89f7a19..6b04dc38c2846 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -46,7 +46,4 @@ public String toString() { return "SUM(" + column.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java index aa1fa082dc92c..af87e76d2ff7d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java @@ -37,7 +37,4 @@ public abstract class Filter implements Expression, Serializable { * Returns list of columns that are referenced by this filter. 
*/ public abstract NamedReference[] references(); - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 4e6c59e2881fb..1b178d7f2be74 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -22,18 +22,19 @@ /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to - * push down aggregates. Spark assumes that the data source can't fully complete the - * grouping work, and will group the data source output again. For queries like - * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate - * to the data source, the data source can still output data with duplicated keys, which is OK - * as Spark will do GROUP BY key again. The final query plan can be something like this: + * push down aggregates. + *

+ * If the data source can't fully complete the grouping work, then
+ * {@link #supportCompletePushDown()} should return false, and Spark will group the data source
+ * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
+ * the aggregate to the data source, the data source can still output data with duplicated keys,
+ * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
  * <pre>
- *   Aggregate [key#1], [min(min(value)#2) AS m#3]
- *     +- RelationV2[key#1, min(value)#2]
+ *   Aggregate [key#1], [min(min_value#2) AS m#3]
+ *     +- RelationV2[key#1, min_value#2]
  * </pre>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
- *
  *

* When pushing down operators, Spark pushes down filters to the data source first, then push down * aggregates or apply column pruning. Depends on data source implementation, aggregates may or @@ -46,8 +47,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder { /** - * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg - * and final-agg when the aggregation operation can be pushed down to the datasource completely. + * Whether the datasource support complete aggregation push-down. Spark will do grouping again + * if this method returns false. * * @return true if the aggregation can be pushed down to datasource completely, false otherwise. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index e52654ac69c96..e3eab6f6730f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -88,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R override def arguments: Array[Expression] = Array(ref) - override def describe: String = name + "(" + reference.describe + ")" - - override def toString: String = describe + override def toString: String = name + "(" + reference.describe + ")" protected def withNewRef(ref: NamedReference): Transform @@ -114,7 +112,7 @@ private[sql] final case class BucketTransform( override def arguments: Array[Expression] = numBuckets +: columns.toArray - override def describe: String = + override def toString: String = if (sortedColumns.nonEmpty) { s"bucket(${arguments.map(_.describe).mkString(", ")}," + s" ${sortedColumns.map(_.describe).mkString(", ")})" @@ -122,8 +120,6 @@ private[sql] final case class BucketTransform( s"bucket(${arguments.map(_.describe).mkString(", ")})" } - override def toString: String = describe - override def withReferences(newReferences: Seq[NamedReference]): Transform = { this.copy(columns = newReferences) } @@ -169,9 +165,7 @@ private[sql] final case class ApplyTransform( arguments.collect { case named: NamedReference => named } } - override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" - - override def toString: String = describe + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" } /** @@ -338,21 +332,19 @@ private[sql] object HoursTransform { } private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { - override def describe: String = { + override def toString: String = { if (dataType.isInstanceOf[StringType]) { s"'$value'" } else { s"$value" } } - override def toString: String = describe } private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def fieldNames: Array[String] = parts.toArray - override def describe: String = parts.quoted - override def toString: String = describe + override def toString: String = parts.quoted } private[sql] object FieldReference { @@ -366,7 +358,7 @@ private[sql] final case class SortValue( direction: SortDirection, nullOrdering: NullOrdering) extends SortOrder { - override def describe(): String = s"$expression $direction $nullOrdering" + override def toString(): String = s"$expression $direction 
$nullOrdering" } private[sql] object SortValue { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 340d225f80fdb..b2371ce667ffc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -28,7 +28,7 @@ class TransformExtractorSuite extends SparkFunSuite { private def lit[T](literal: T): Literal[T] = new Literal[T] { override def value: T = literal override def dataType: DataType = catalyst.expressions.Literal(literal).dataType - override def describe: String = literal.toString + override def toString: String = literal.toString } /** @@ -36,7 +36,7 @@ class TransformExtractorSuite extends SparkFunSuite { */ private def ref(names: String*): NamedReference = new NamedReference { override def fieldNames: Array[String] = names.toArray - override def describe: String = names.mkString(".") + override def toString: String = names.mkString(".") } /** @@ -46,7 +46,7 @@ class TransformExtractorSuite extends SparkFunSuite { override def name: String = func override def references: Array[NamedReference] = Array(ref) override def arguments: Array[Expression] = Array(ref) - override def describe: String = ref.describe + override def toString: String = ref.describe } test("Identity extractor") { @@ -135,7 +135,7 @@ class TransformExtractorSuite extends SparkFunSuite { override def name: String = "bucket" override def references: Array[NamedReference] = Array(col) override def arguments: Array[Expression] = Array(lit(16), col) - override def describe: String = s"bucket(16, ${col.describe})" + override def toString: String = s"bucket(16, ${col.describe})" } bucketTransform match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 79f2b981b6499..4bd6c239a3367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -157,8 +157,8 @@ case class RowDataSourceScanExec( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => - Map("PushedAggregates" -> seqToString(v.aggregateExpressions), - "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ + Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), + "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c296ba9f29dd3..fa5429678c1db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import 
org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -717,8 +717,10 @@ object DataSourceStrategy Some(new Count(FieldReference(name), agg.isDistinct)) case _ => None } - case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => + case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), agg.isDistinct)) + case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name)))) case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 3a792f4660ff4..1918dc935c95b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { r, normalizedAggregates, normalizedGroupingExpressions) if (pushedAggregates.isEmpty) { aggNode // return original plan node + } else if (!supportPartialAggPushDown(pushedAggregates.get) && + !r.supportCompletePushDown()) { + aggNode // return original plan node } else { // No need to do column pruning because only the aggregate columns are used as // DataSourceV2ScanRelation output columns. All the other columns are not @@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { """.stripMargin) val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown()) { val projectExpressions = resultExpressions.map { expr => // TODO At present, only push down group by attribute is supported. 
@@ -209,6 +210,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { + // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. + agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) + } + private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = if (aggAttribute.dataType == aggDataType) { aggAttribute @@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform { case sample: Sample => sample.child match { - case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => val tableSample = TableSampleInfo( sample.lowerBound, sample.upperBound, @@ -282,7 +288,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } operation case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty => + if filter.isEmpty => val orders = DataSourceStrategy.translateSortOrders(order) if (orders.length == order.length) { val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 01722e883831f..2d01a3e6842b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -79,7 +79,7 @@ case class JDBCScanBuilder( if (!jdbcOptions.pushDownAggregate) return false val dialect = JdbcDialects.get(jdbcOptions.url) - val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_)) + val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate) if (compiledAggs.length != aggregation.aggregateExpressions.length) return false val groupByCols = aggregation.groupByColumns.map { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e0f11afcc2550..344842d30b232 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -219,7 +219,11 @@ abstract class JdbcDialect extends Serializable with Logging{ val column = quoteIdentifier(sum.column.fieldNames.head) Some(s"SUM($distinct$column)") case _: CountStar => - 
Some(s"COUNT(*)") + Some("COUNT(*)") + case f: GeneralAggregateFunc if f.name() == "AVG" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"AVG($distinct${f.inputs().head})") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 3921b3b04a91d..bede7cf1d3f00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -386,20 +386,20 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("scan with aggregate push-down: MAX MIN with filter and group by") { - val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + test("scan with aggregate push-down: MAX AVG with filter and group by") { + val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") checkFiltersRemoved(df) checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) + checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } private def checkFiltersRemoved(df: DataFrame): Unit = { @@ -409,19 +409,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(filters.isEmpty) } - test("scan with aggregate push-down: MAX MIN with filter without group by") { - val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + test("scan with aggregate push-down: MAX AVG with filter without group by") { + val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0") checkFiltersRemoved(df) checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [MAX(ID), MIN(ID)], " + + "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + "PushedGroupByColumns: []" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(2, 1))) + checkAnswer(df, Seq(Row(2, 1.0))) } test("scan with aggregate push-down: aggregate + number") { From 127e41253ab0ae78690f9366ca4e46b3cc029651 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 31 Dec 2021 13:23:39 +0800 Subject: [PATCH 2/5] Update sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 1918dc935c95b..06dd4fbe42d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -212,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with 
PredicateHelper { private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) + agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc]) } private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = From 0eed98e79a831ad33df1be09db08fb4ad0256485 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Jan 2022 12:57:57 +0800 Subject: [PATCH 3/5] Update sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 06dd4fbe42d28..bbf8abe5ca1e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -212,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc]) + !agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc]) } private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = From 6dc33706253b652a678ec04d769c40b4128d9dda Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Jan 2022 12:59:59 +0800 Subject: [PATCH 4/5] Update sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index bbf8abe5ca1e1..1918dc935c95b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -212,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. 
- !agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc]) + agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) } private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = From a296a2c4fc0658ed760b24477f281dbecd2dcabf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Jan 2022 13:10:57 +0800 Subject: [PATCH 5/5] one more test --- .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index bede7cf1d3f00..0d54a21bf7919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -424,6 +424,23 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(2, 1.0))) } + test("partitioned scan with aggregate push-down: complete push-down only") { + withTempView("v") { + spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .createTempView("v") + val df = sql("select AVG(SALARY) FROM v GROUP BY name") + // Partitioned JDBC Scan doesn't support complete aggregate push-down, and AVG requires + // complete push-down so aggregate is not pushed at the end. + checkAggregateRemoved(df, removed = false) + checkAnswer(df, Seq(Row(9000.0), Row(10000.0), Row(10000.0), Row(12000.0), Row(12000.0))) + } + } + test("scan with aggregate push-down: aggregate + number") { val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee") checkAggregateRemoved(df)
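A minimal usage sketch, not part of the patch series, showing what the API added above looks like from a connector's side: how GeneralAggregateFunc renders through the new default Expression.describe(), and how a dialect-style compile step might translate it, mirroring the AVG branch added to JdbcDialect.compileAggregate. The object name, the ref helper (borrowed from TransformExtractorSuite above) and the BONUS column are illustrative only.

import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}

object GeneralAggregateFuncExample {
  // Minimal NamedReference for a top-level column, same trick as in TransformExtractorSuite.
  private def ref(names: String*): NamedReference = new NamedReference {
    override def fieldNames: Array[String] = names.toArray
    override def toString: String = names.mkString(".")
  }

  // Mirrors the new AVG branch in JdbcDialect.compileAggregate: only single-input AVG is
  // translated; any other general aggregate yields None and is therefore not pushed down.
  def compileAggregate(agg: AggregateFunc): Option[String] = agg match {
    case f: GeneralAggregateFunc if f.name() == "AVG" && f.inputs().length == 1 =>
      val distinct = if (f.isDistinct) "DISTINCT " else ""
      Some(s"AVG($distinct${f.inputs().head.describe()})")
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val avg = new GeneralAggregateFunc("AVG", true, Array(ref("BONUS")))
    println(avg)                   // AVG(DISTINCT BONUS)
    println(compileAggregate(avg)) // Some(AVG(DISTINCT BONUS))
  }
}

Since Spark does not know the aggregation buffer of a GeneralAggregateFunc, V2ScanRelationPushDown only pushes it when the scan builder reports supportCompletePushDown(); the partitioned-scan test added in PATCH 5/5 covers the case where complete push-down is not possible and the aggregate stays in Spark.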