diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 1b178d7f2be74..4d88ec19c897b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -25,10 +25,11 @@
  * push down aggregates.
  * <p>
 * If the data source can't fully complete the grouping work, then
- * {@link #supportCompletePushDown()} should return false, and Spark will group the data source
- * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
- * the aggregate to the data source, the data source can still output data with duplicated keys,
- * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
+ * {@link #supportCompletePushDown(Aggregation)} should return false, and Spark will group the data
+ * source output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after
+ * pushing down the aggregate to the data source, the data source can still output data with
+ * duplicated keys, which is OK as Spark will do GROUP BY key again. The final query plan can be
+ * something like this:
 * <pre>
* Aggregate [key#1], [min(min_value#2) AS m#3]
* +- RelationV2[key#1, min_value#2]
@@ -50,15 +51,17 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
* Whether the datasource support complete aggregation push-down. Spark will do grouping again
* if this method returns false.
*
+   * @param aggregation the aggregation in the SQL statement.
* @return true if the aggregation can be pushed down to datasource completely, false otherwise.
*/
- default boolean supportCompletePushDown() { return false; }
+ default boolean supportCompletePushDown(Aggregation aggregation) { return false; }
/**
* Pushes down Aggregation to datasource. The order of the datasource scan output columns should
* be: grouping columns, aggregate columns (in the same order as the aggregate functions in
* the given Aggregation).
*
+   * @param aggregation the aggregation in the SQL statement.
* @return true if the aggregation can be pushed down to datasource, false otherwise.
*/
boolean pushAggregation(Aggregation aggregation);
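For context, a data source opting into the new signature might look like the following Scala sketch (the class name MyScanBuilder is hypothetical and not part of this patch). The builder records the pushed Aggregation and only claims complete push-down for aggregations it can fully answer, here the no-grouping case:

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}

class MyScanBuilder extends SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  // Claim complete push-down only when there is no grouping, so the source can return a
  // single fully aggregated row; for grouped queries Spark will group the output again.
  override def supportCompletePushDown(aggregation: Aggregation): Boolean =
    aggregation.groupByColumns().isEmpty

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    pushedAggregation = Some(aggregation)
    true
  }

  // Building the actual Scan (grouping columns first, then aggregate columns, in the order
  // required by the contract above) is out of scope for this sketch.
  override def build(): Scan =
    throw new UnsupportedOperationException("sketch only")
}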
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 1918dc935c95b..dec7189ac698d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -110,7 +110,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
- !r.supportCompletePushDown()) {
+ !r.supportCompletePushDown(pushedAggregates.get)) {
aggNode // return original plan node
} else {
// No need to do column pruning because only the aggregate columns are used as
@@ -149,7 +149,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
- if (r.supportCompletePushDown()) {
+ if (r.supportCompletePushDown(pushedAggregates.get)) {
val projectExpressions = resultExpressions.map { expr =>
// TODO At present, only push down group by attribute is supported.
// In future, more attribute conversion is extended here. e.g. GetStructField
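Taken together, the rule's flow is unchanged: if nothing was pushed, or the pushed aggregation supports neither partial nor complete completion, the original Aggregate node is kept; otherwise the scan is rewritten, leaving only a Project on top when the push-down is complete and a re-grouping Aggregate when it is partial (partial meaning Spark can safely merge the source's pre-aggregated output, as it can for MIN/MAX/SUM/COUNT but generally not for AVG). A self-contained sketch of that decision, with made-up names and booleans standing in for the real plan rewrites:

object AggregatePushDownDecision {
  sealed trait Outcome
  case object KeepOriginalAggregate extends Outcome // nothing pushed; Spark aggregates everything
  case object PartialPushDown extends Outcome       // source pre-aggregates; Spark groups again on top
  case object CompletePushDown extends Outcome      // source output is final; only a Project remains

  def decide(
      aggregatesPushed: Boolean,
      supportsPartial: Boolean,
      supportsComplete: Boolean): Outcome = {
    if (!aggregatesPushed) {
      KeepOriginalAggregate
    } else if (!supportsPartial && !supportsComplete) {
      KeepOriginalAggregate
    } else if (supportsComplete) {
      CompletePushDown
    } else {
      PartialPushDown
    }
  }

  def main(args: Array[String]): Unit = {
    // With this change, supportsComplete is evaluated per Aggregation, so the same source can
    // answer differently from query to query (see the JDBC condition below).
    assert(decide(aggregatesPushed = true, supportsPartial = true, supportsComplete = false) == PartialPushDown)
    assert(decide(aggregatesPushed = true, supportsPartial = false, supportsComplete = true) == CompletePushDown)
    assert(decide(aggregatesPushed = false, supportsPartial = true, supportsComplete = true) == KeepOriginalAggregate)
  }
}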
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 2d01a3e6842b3..61bf729bc8fbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -72,8 +72,12 @@ case class JDBCScanBuilder(
private var pushedGroupByCols: Option[Array[String]] = None
- override def supportCompletePushDown: Boolean =
- jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)
+ override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
+ lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+ jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
+ (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+ jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
+ }
override def pushAggregation(aggregation: Aggregation): Boolean = {
if (!jdbcOptions.pushDownAggregate) return false
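Spelled out, the new condition allows complete push-down in two cases: the scan runs as a single partition (numPartitions unset or 1), or the query groups by exactly the JDBC partition column, in which case every row of a group falls into the same partition and each partition's local aggregate is already the final answer. A standalone restatement of the predicate, covering the same boundary cases the new JDBCV2Suite tests below exercise (the helper name and the Seq[Seq[String]] encoding of group-by field paths are illustrative only):

object CompletePushDownCheck {
  // groupByColumns models Aggregation.groupByColumns(): one entry per column, each entry
  // being the (possibly nested) field-name path of that column.
  def isComplete(
      numPartitions: Option[Int],
      partitionColumn: Option[String],
      groupByColumns: Seq[Seq[String]]): Boolean = {
    numPartitions.forall(_ == 1) ||
      (groupByColumns.length == 1 && groupByColumns.head.length == 1 &&
        partitionColumn.exists(_.equalsIgnoreCase(groupByColumns.head.head)))
  }

  def main(args: Array[String]): Unit = {
    assert(isComplete(Some(2), Some("dept"), Seq(Seq("dept"))))               // group by the partition column
    assert(!isComplete(Some(2), Some("dept"), Seq(Seq("dept"), Seq("name")))) // multiple group-by columns
    assert(!isComplete(Some(2), Some("dept"), Seq(Seq("name"))))              // a different group-by column
    assert(isComplete(None, None, Seq(Seq("name"))))                          // numPartitions unset => one partition
  }
}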
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 0d54a21bf7919..9d37a85a2c916 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -685,6 +685,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(query, Seq(Row(47100.0)))
}
+ test("scan with aggregate push-down: partition columns are same as group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"dept")
+ .count()
+ checkAggregateRemoved(df)
+ checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
+ }
+
test("scan with aggregate push-down: aggregate over alias NOT push down") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
@@ -730,4 +743,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
}
checkAnswer(df, Seq(Row(1), Row(2), Row(2)))
}
+
+ test("scan with aggregate push-down: partition columns with multi group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"dept", $"name")
+ .count()
+ checkAggregateRemoved(df, false)
+ checkAnswer(df, Seq(Row(1, "amy", 1), Row(1, "cathy", 1),
+ Row(2, "alex", 1), Row(2, "david", 1), Row(6, "jen", 1)))
+ }
+
+ test("scan with aggregate push-down: partition columns is different from group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .count()
+ checkAggregateRemoved(df, false)
+ checkAnswer(df,
+ Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
+ }
}