SupportsPushDownLimit.java
@@ -33,4 +33,10 @@ public interface SupportsPushDownLimit extends ScanBuilder {
    * Pushes down LIMIT to the data source.
    */
   boolean pushLimit(int limit);
+
+  /**
+   * Whether the LIMIT is partially pushed or not. If it returns true, then Spark will do LIMIT
+   * again. This method will only be called when {@link #pushLimit} returns true.
+   */
+  default boolean isPartiallyPushed() { return true; }
 }
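
For context, here is a minimal sketch of how a connector could use the new hook (illustrative only, not part of this change; the class name is made up). A source that can enforce LIMIT exactly reports the pushdown as complete by overriding isPartiallyPushed to return false, so Spark does not need to apply LIMIT again on top of the scan:

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit}
import org.apache.spark.sql.types.StructType

// Hypothetical connector code, for illustration only.
class ExactLimitScanBuilder(schema: StructType) extends SupportsPushDownLimit {
  private var limit: Option[Int] = None

  override def pushLimit(l: Int): Boolean = {
    limit = Some(l)   // the source will stop reading after l rows
    true              // LIMIT accepted
  }

  // The source applies LIMIT exactly, so the pushdown is complete rather than partial
  // and Spark may remove its own Limit operator.
  override def isPartiallyPushed(): Boolean = false

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = schema
    // a real source would also implement toBatch() here
  }
}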
PushDownUtils.scala
@@ -118,11 +118,11 @@ object PushDownUtils extends PredicateHelper {
   /**
    * Pushes down LIMIT to the data source Scan
    */
-  def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = {
+  def pushLimit(scanBuilder: ScanBuilder, limit: Int): (Boolean, Boolean) = {
     scanBuilder match {
-      case s: SupportsPushDownLimit =>
-        s.pushLimit(limit)
-      case _ => false
+      case s: SupportsPushDownLimit if s.pushLimit(limit) =>
+        (true, s.isPartiallyPushed)
+      case _ => (false, false)
     }
   }

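To make the new return value concrete, a hedged sketch follows (not code from this change; fakeBuilder is a made-up helper). It shows what the (isPushed, isPartiallyPushed) pair would be for the three kinds of ScanBuilder:

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit}
import org.apache.spark.sql.types.StructType

// Builds a throwaway ScanBuilder whose LIMIT-pushdown behaviour is fixed by the two flags.
def fakeBuilder(acceptsLimit: Boolean, partial: Boolean): SupportsPushDownLimit =
  new SupportsPushDownLimit {
    override def pushLimit(limit: Int): Boolean = acceptsLimit
    override def isPartiallyPushed(): Boolean = partial
    override def build(): Scan = new Scan {
      override def readSchema(): StructType = new StructType()
    }
  }

// Expected results of PushDownUtils.pushLimit(builder, 10):
//   fakeBuilder(acceptsLimit = false, partial = true)  -> (false, false)  keep Spark's Limit
//   fakeBuilder(acceptsLimit = true,  partial = true)  -> (true,  true)   keep Spark's Limit
//   fakeBuilder(acceptsLimit = true,  partial = false) -> (true,  false)  Limit can be removed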
V2ScanRelationPushDown.scala
@@ -360,13 +360,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
     }
   }

-  private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
+  private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
     case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
-      val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
-      if (limitPushed) {
+      val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
+      if (isPushed) {
         sHolder.pushedLimit = Some(limit)
       }
-      operation
+      (operation, isPushed && !isPartiallyPushed)
     case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
         if filter.isEmpty && CollapseProject.canCollapseExpressions(
           order, project, alwaysInline = true) =>
@@ -380,27 +380,32 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           sHolder.pushedLimit = Some(limit)
           sHolder.sortOrders = orders
           if (isPartiallyPushed) {
-            s
+            (s, false)
           } else {
-            operation
+            (operation, true)
           }
         } else {
-          s
+          (s, false)
         }
       } else {
-        s
+        (s, false)
       }
     case p: Project =>
-      val newChild = pushDownLimit(p.child, limit)
-      p.withNewChildren(Seq(newChild))
-    case other => other
+      val (newChild, isPartiallyPushed) = pushDownLimit(p.child, limit)
+      (p.withNewChildren(Seq(newChild)), isPartiallyPushed)
+    case other => (other, false)
   }

   def pushDownLimits(plan: LogicalPlan): LogicalPlan = plan.transform {
     case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
-      val newChild = pushDownLimit(child, limitValue)
-      val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
-      globalLimit.withNewChildren(Seq(newLocalLimit))
+      val (newChild, canRemoveLimit) = pushDownLimit(child, limitValue)
+      if (canRemoveLimit) {
+        newChild
+      } else {
+        val newLocalLimit =
+          globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
+        globalLimit.withNewChildren(Seq(newLocalLimit))
+      }
   }

   private def getWrappedScan(
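Sketch of the end-to-end effect (illustrative; the plan strings are abbreviated and the table is the H2 test table used by the suite below). When the source reports the LIMIT as fully pushed, the optimizer can now drop both GlobalLimit and LocalLimit instead of keeping them on top of the scan:

// Assumes the H2 test catalog configured by JDBCV2Suite; plan output is approximate.
val df = spark.read.table("h2.test.employee").limit(3)
println(df.queryExecution.optimizedPlan)
// Before this change (LIMIT pushed but still enforced again by Spark):
//   GlobalLimit 3
//   +- LocalLimit 3
//      +- RelationV2[...] ... PushedLimit: LIMIT 3
// After this change (source enforces the limit exactly, so the Limit nodes are removed):
//   RelationV2[...] ... PushedLimit: LIMIT 3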
JDBCV2Suite.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
@@ -141,9 +141,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     assert(scan.schema.names.sameElements(names))
   }

+  private def checkLimitRemoved(df: DataFrame, removed: Boolean = true): Unit = {
+    val limits = df.queryExecution.optimizedPlan.collect {
+      case g: GlobalLimit => g
+      case limit: LocalLimit => limit
+    }
+    if (removed) {
+      assert(limits.isEmpty)
+    } else {
+      assert(limits.nonEmpty)
+    }
+  }
+
   test("simple scan with LIMIT") {
     val df1 = spark.read.table("h2.test.employee")
       .where($"dept" === 1).limit(1)
+    checkLimitRemoved(df1)
     checkPushedInfo(df1,
       "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ")
     checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true)))
@@ -156,12 +169,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.table("h2.test.employee")
.filter($"dept" > 1)
.limit(1)
checkLimitRemoved(df2, false)
checkPushedInfo(df2,
"PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1")
checkSchemaNames(df3, Seq("NAME"))
checkLimitRemoved(df3)
checkPushedInfo(df3,
"PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
checkAnswer(df3, Seq(Row("alex")))
@@ -170,6 +185,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.limit(1)
checkLimitRemoved(df4, false)
checkPushedInfo(df4,
"PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ")
checkAnswer(df4, Seq(Row(1, 19000.00)))
@@ -181,6 +197,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
.filter(name($"shortName"))
.limit(1)
checkLimitRemoved(df5, false)
// LIMIT is pushed down only if all the filters are pushed down
checkPushedInfo(df5, "PushedFilters: [], ")
checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
@@ -203,6 +220,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.sort("salary")
.limit(1)
checkSortRemoved(df1)
checkLimitRemoved(df1)
checkPushedInfo(df1,
"PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
@@ -217,6 +235,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.orderBy($"salary")
.limit(1)
checkSortRemoved(df2)
checkLimitRemoved(df2)
checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
"PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
@@ -231,6 +250,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.orderBy($"salary".desc)
.limit(1)
checkSortRemoved(df3, false)
checkLimitRemoved(df3, false)
checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
"PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
@@ -239,6 +259,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1")
checkSchemaNames(df4, Seq("NAME"))
checkSortRemoved(df4)
checkLimitRemoved(df4)
checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
"PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
checkAnswer(df4, Seq(Row("david")))
@@ -256,6 +277,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.orderBy("DEPT")
.limit(1)
checkSortRemoved(df6, false)
checkLimitRemoved(df6, false)
checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
" PushedFilters: [], PushedGroupByColumns: [DEPT], ")
checkAnswer(df6, Seq(Row(1, 19000.00)))
@@ -270,6 +292,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .limit(1)
     // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df7, false)
+    checkLimitRemoved(df7, false)
     checkPushedInfo(df7, "PushedFilters: [], ")
     checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))

@@ -278,6 +301,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.sort(sub($"NAME"))
.limit(1)
checkSortRemoved(df8, false)
checkLimitRemoved(df8, false)
checkPushedInfo(df8, "PushedFilters: [], ")
checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
}
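
To exercise the touched suite locally, something like the following usually works (an assumption about the standard Spark development setup; the exact sbt target may differ):

build/sbt "sql/testOnly *JDBCV2Suite"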