diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
index 0212895fde079..cba1592c4fa14 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
@@ -35,4 +35,10 @@ public interface SupportsPushDownTopN extends ScanBuilder {
    * Pushes down top N to the data source.
    */
   boolean pushTopN(SortOrder[] orders, int limit);
+
+  /**
+   * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N
+   * again. This method will only be called when {@link #pushTopN} returns true.
+   */
+  default boolean isPartiallyPushed() { return true; }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 1149bff7d2da7..f72310b5d7afa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -129,11 +129,14 @@ object PushDownUtils extends PredicateHelper {
   /**
    * Pushes down top N to the data source Scan
    */
-  def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
+  def pushTopN(
+      scanBuilder: ScanBuilder,
+      order: Array[SortOrder],
+      limit: Int): (Boolean, Boolean) = {
     scanBuilder match {
-      case s: SupportsPushDownTopN =>
-        s.pushTopN(order, limit)
-      case _ => false
+      case s: SupportsPushDownTopN if s.pushTopN(order, limit) =>
+        (true, s.isPartiallyPushed)
+      case _ => (false, false)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index eaa30f90b77f5..c8ef8b00d0cf9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -381,11 +381,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
       val orders = DataSourceStrategy.translateSortOrders(newOrder)
       if (orders.length == order.length) {
-        val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
-        if (topNPushed) {
+        val (isPushed, isPartiallyPushed) =
+          PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
+        if (isPushed) {
           sHolder.pushedLimit = Some(limit)
           sHolder.sortOrders = orders
-          operation
+          if (isPartiallyPushed) {
+            s
+          } else {
+            operation
+          }
         } else {
           s
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 475f563856f82..0a1542a42956d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -146,6 +146,8 @@ case class JDBCScanBuilder(
     false
   }
 
+  override def isPartiallyPushed(): Boolean = jdbcOptions.numPartitions.map(_ > 1).getOrElse(false)
+
   override def pruneColumns(requiredSchema: StructType): Unit = {
     // JDBC doesn't support nested column pruning.
     // TODO (SPARK-32593): JDBC support nested column and nested column pruning.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 3ab87ee3387e4..afbdc604b8a18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -199,8 +199,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
 
-    val df2 = spark.read.table("h2.test.employee")
-      .where($"dept" === 1).orderBy($"salary").limit(1)
+    val df2 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "1")
+      .table("h2.test.employee")
+      .where($"dept" === 1)
+      .orderBy($"salary")
+      .limit(1)
     checkSortRemoved(df2)
     checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
       "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
@@ -215,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .filter($"dept" > 1)
      .orderBy($"salary".desc)
       .limit(1)
-    checkSortRemoved(df3)
+    checkSortRemoved(df3, false)
     checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
       "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
     checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
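
For context, here is a minimal sketch of how a third-party DSv2 connector might wire up the new hook, mirroring the `JDBCScanBuilder` change above. `ExampleScanBuilder`, its `numPartitions` constructor argument, and the anonymous `Scan` are illustrative only and are not part of this patch.

```scala
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

// Hypothetical connector-side ScanBuilder that accepts a pushed top N.
class ExampleScanBuilder(schema: StructType, numPartitions: Int)
  extends SupportsPushDownTopN {

  private var pushedOrders: Array[SortOrder] = Array.empty
  private var pushedLimit: Int = -1

  // Accept the top N; each underlying split applies ORDER BY ... LIMIT locally.
  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    pushedOrders = orders
    pushedLimit = limit
    true
  }

  // With a single partition the source returns a globally sorted, limited result,
  // so Spark can drop its own Sort + Limit. With several partitions each split is
  // only locally sorted, so Spark must still run the final top N itself.
  override def isPartiallyPushed(): Boolean = numPartitions > 1

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = schema
  }
}
```

Because `isPartiallyPushed` defaults to `true`, sources that only implement `pushTopN` are handled conservatively: `V2ScanRelationPushDown` keeps the `Sort` node (`s`) in the plan and only removes it (returning `operation`) when the source reports the top N as fully pushed.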