diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 1b178d7f2be74..4d88ec19c897b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -25,10 +25,11 @@ * push down aggregates. *

* If the data source can't fully complete the grouping work, then - * {@link #supportCompletePushDown()} should return false, and Spark will group the data source - * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down - * the aggregate to the data source, the data source can still output data with duplicated keys, - * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this: + * {@link #supportCompletePushDown(Aggregation)} should return false, and Spark will group the data + * source output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after + * pushing down the aggregate to the data source, the data source can still output data with + * duplicated keys, which is OK as Spark will do GROUP BY key again. The final query plan can be + * something like this: *

  *   Aggregate [key#1], [min(min_value#2) AS m#3]
  *     +- RelationV2[key#1, min_value#2]
@@ -50,15 +51,17 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Whether the datasource support complete aggregation push-down. Spark will do grouping again
    * if this method returns false.
    *
+   * @param aggregation Aggregation in SQL statement.
    * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
    */
-  default boolean supportCompletePushDown() { return false; }
+  default boolean supportCompletePushDown(Aggregation aggregation) { return false; }
 
   /**
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
    *
+   * @param aggregation Aggregation in SQL statement.
    * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 1918dc935c95b..dec7189ac698d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -110,7 +110,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
               if (pushedAggregates.isEmpty) {
                 aggNode // return original plan node
               } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
-                !r.supportCompletePushDown()) {
+                !r.supportCompletePushDown(pushedAggregates.get)) {
                 aggNode // return original plan node
               } else {
                 // No need to do column pruning because only the aggregate columns are used as
@@ -149,7 +149,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
                 val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
                 val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-                if (r.supportCompletePushDown()) {
+                if (r.supportCompletePushDown(pushedAggregates.get)) {
                   val projectExpressions = resultExpressions.map { expr =>
                     // TODO At present, only push down group by attribute is supported.
                     // In future, more attribute conversion is extended here. e.g. GetStructField
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 2d01a3e6842b3..61bf729bc8fbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -72,8 +72,12 @@ case class JDBCScanBuilder(
 
   private var pushedGroupByCols: Option[Array[String]] = None
 
-  override def supportCompletePushDown: Boolean =
-    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)
+  override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
+    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
+      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+        jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
+  }
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 0d54a21bf7919..9d37a85a2c916 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -685,6 +685,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
+  test("scan with aggregate push-down: partition columns are same as group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"dept")
+      .count()
+    checkAggregateRemoved(df)
+    checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
@@ -730,4 +743,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     checkAnswer(df, Seq(Row(1), Row(2), Row(2)))
   }
+
+  test("scan with aggregate push-down: partition columns with multi group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"dept", $"name")
+      .count()
+    checkAggregateRemoved(df, false)
+    checkAnswer(df, Seq(Row(1, "amy", 1), Row(1, "cathy", 1),
+      Row(2, "alex", 1), Row(2, "david", 1), Row(6, "jen", 1)))
+  }
+
+  test("scan with aggregate push-down: partition columns is different from group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"name")
+      .count()
+    checkAggregateRemoved(df, false)
+    checkAnswer(df,
+      Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
+  }
 }