diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
index fa6447bc068d5..035154d08450a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
@@ -33,4 +33,10 @@ public interface SupportsPushDownLimit extends ScanBuilder {
    * Pushes down LIMIT to the data source.
    */
   boolean pushLimit(int limit);
+
+  /**
+   * Whether the LIMIT is partially pushed or not. If it returns true, then Spark will do LIMIT
+   * again. This method will only be called when {@link #pushLimit} returns true.
+   */
+  default boolean isPartiallyPushed() { return true; }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index f72310b5d7afa..37c180ef5d353 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -118,11 +118,11 @@ object PushDownUtils extends PredicateHelper {
   /**
    * Pushes down LIMIT to the data source Scan
    */
-  def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = {
+  def pushLimit(scanBuilder: ScanBuilder, limit: Int): (Boolean, Boolean) = {
     scanBuilder match {
-      case s: SupportsPushDownLimit =>
-        s.pushLimit(limit)
-      case _ => false
+      case s: SupportsPushDownLimit if s.pushLimit(limit) =>
+        (true, s.isPartiallyPushed)
+      case _ => (false, false)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 5371829271d36..6455e25089276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -360,13 +360,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
     }
   }
 
-  private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
+  private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
     case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
-      val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
-      if (limitPushed) {
+      val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
+      if (isPushed) {
         sHolder.pushedLimit = Some(limit)
       }
-      operation
+      (operation, isPushed && !isPartiallyPushed)
     case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
         if filter.isEmpty && CollapseProject.canCollapseExpressions(
           order, project, alwaysInline = true) =>
@@ -380,27 +380,32 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           sHolder.pushedLimit = Some(limit)
           sHolder.sortOrders = orders
           if (isPartiallyPushed) {
-            s
+            (s, false)
           } else {
-            operation
+            (operation, true)
           }
         } else {
-          s
+          (s, false)
         }
       } else {
-        s
+        (s, false)
       }
     case p: Project =>
-      val newChild = pushDownLimit(p.child, limit)
-      p.withNewChildren(Seq(newChild))
-    case other => other
+      val (newChild, isPartiallyPushed) = pushDownLimit(p.child, limit)
+      (p.withNewChildren(Seq(newChild)), isPartiallyPushed)
+    case other => (other, false)
   }
 
   def pushDownLimits(plan: LogicalPlan): LogicalPlan = plan.transform {
     case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
-      val newChild = pushDownLimit(child, limitValue)
-      val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
-      globalLimit.withNewChildren(Seq(newLocalLimit))
+      val (newChild, canRemoveLimit) = pushDownLimit(child, limitValue)
+      if (canRemoveLimit) {
+        newChild
+      } else {
+        val newLocalLimit =
+          globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
+        globalLimit.withNewChildren(Seq(newLocalLimit))
+      }
   }
 
   private def getWrappedScan(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 67a02904660c3..aeae819205da1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
@@ -141,9 +141,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     assert(scan.schema.names.sameElements(names))
   }
 
+  private def checkLimitRemoved(df: DataFrame, removed: Boolean = true): Unit = {
+    val limits = df.queryExecution.optimizedPlan.collect {
+      case g: GlobalLimit => g
+      case limit: LocalLimit => limit
+    }
+    if (removed) {
+      assert(limits.isEmpty)
+    } else {
+      assert(limits.nonEmpty)
+    }
+  }
+
   test("simple scan with LIMIT") {
     val df1 = spark.read.table("h2.test.employee")
       .where($"dept" === 1).limit(1)
+    checkLimitRemoved(df1)
     checkPushedInfo(df1,
       "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ")
     checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true)))
@@ -156,12 +169,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .table("h2.test.employee")
       .filter($"dept" > 1)
       .limit(1)
+    checkLimitRemoved(df2, false)
     checkPushedInfo(df2,
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
     checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1")
     checkSchemaNames(df3, Seq("NAME"))
+    checkLimitRemoved(df3)
     checkPushedInfo(df3,
       "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
     checkAnswer(df3, Seq(Row("alex")))
@@ -170,6 +185,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .table("h2.test.employee")
       .groupBy("DEPT").sum("SALARY")
       .limit(1)
+    checkLimitRemoved(df4, false)
     checkPushedInfo(df4,
       "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ")
     checkAnswer(df4, Seq(Row(1, 19000.00)))
@@ -181,6 +197,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
       .filter(name($"shortName"))
       .limit(1)
+    checkLimitRemoved(df5, false)
     // LIMIT is pushed down only if all the filters are pushed down
     checkPushedInfo(df5, "PushedFilters: [], ")
     checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
@@ -203,6 +220,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .sort("salary")
       .limit(1)
     checkSortRemoved(df1)
+    checkLimitRemoved(df1)
     checkPushedInfo(df1,
       "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
@@ -217,6 +235,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
      .orderBy($"salary")
       .limit(1)
     checkSortRemoved(df2)
+    checkLimitRemoved(df2)
     checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
       "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
@@ -231,6 +250,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .orderBy($"salary".desc)
       .limit(1)
     checkSortRemoved(df3, false)
+    checkLimitRemoved(df3, false)
     checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
       "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
     checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
@@ -239,6 +259,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1")
     checkSchemaNames(df4, Seq("NAME"))
     checkSortRemoved(df4)
+    checkLimitRemoved(df4)
     checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
       "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
     checkAnswer(df4, Seq(Row("david")))
@@ -256,6 +277,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .orderBy("DEPT")
       .limit(1)
     checkSortRemoved(df6, false)
+    checkLimitRemoved(df6, false)
     checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
       " PushedFilters: [], PushedGroupByColumns: [DEPT], ")
     checkAnswer(df6, Seq(Row(1, 19000.00)))
@@ -270,6 +292,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .limit(1)
     // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df7, false)
+    checkLimitRemoved(df7, false)
     checkPushedInfo(df7, "PushedFilters: [], ")
     checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
 
@@ -278,6 +301,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .sort(sub($"NAME"))
       .limit(1)
     checkSortRemoved(df8, false)
+    checkLimitRemoved(df8, false)
     checkPushedInfo(df8, "PushedFilters: [], ")
     checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
   }
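
Note (not part of the patch): the sketch below illustrates how a connector-side ScanBuilder might use the new isPartiallyPushed hook. A source that enforces the pushed LIMIT exactly can report isPartiallyPushed = false, so pushDownLimits removes the GlobalLimit/LocalLimit operators from the plan, which is what checkLimitRemoved asserts in the tests above. The class ExactLimitScanBuilder and its fields are hypothetical, shown only to make the contract concrete.

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit}

// Hypothetical connector sketch: a ScanBuilder that enforces the pushed LIMIT exactly.
class ExactLimitScanBuilder extends SupportsPushDownLimit {
  private var pushedLimit: Option[Int] = None

  // Accept the LIMIT; this source will return at most `limit` rows.
  override def pushLimit(limit: Int): Boolean = {
    pushedLimit = Some(limit)
    true
  }

  // The limit is fully enforced by the source, so Spark does not need to apply it again
  // and V2ScanRelationPushDown can drop the Limit node.
  override def isPartiallyPushed(): Boolean = false

  // Build the Scan honoring `pushedLimit` (omitted in this sketch).
  override def build(): Scan = ???
}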