9 changes: 9 additions & 0 deletions docs/sql-data-sources-jdbc.md
@@ -246,6 +246,15 @@ logging into the data sources.
<td>read</td>
</tr>

<tr>
<td><code>pushDownLimit</code></td>
<td><code>false</code></td>
<td>
The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. If the value is set to true, LIMIT is pushed down to the JDBC data source. Spark still applies LIMIT on the result returned from the data source even when LIMIT is pushed down (see the usage sketch below).
</td>
<td>read</td>
</tr>

<tr>
<td><code>keytab</code></td>
<td>(none)</td>
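A minimal usage sketch of the `pushDownLimit` option documented above, assuming a reachable JDBC endpoint; the URL, table, and credentials are placeholders, and the LIMIT is only pushed when the scan is planned through the data source V2 JDBC path:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("limit-pushdown-demo").getOrCreate()

// Placeholder connection details for illustration only.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/sales")
  .option("dbtable", "public.orders")
  .option("user", "reader")
  .option("password", "secret")
  .option("pushDownLimit", "true") // default is false: no LIMIT push-down
  .load()

// Spark still applies LIMIT on the returned rows even when it is pushed down.
df.limit(10).show()
```

Whether the LIMIT was actually pushed can be checked in the physical plan, which gains a `PushedLimit` entry elsewhere in this change.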
@@ -21,9 +21,9 @@

/**
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
* interfaces to do operator pushdown, and keep the operator pushdown result in the returned
* {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
* aggregates or applies column pruning.
* interfaces to do operator push down, and keep the operator push down result in the returned
* {@link Scan}. When pushing down operators, the push down order is:
* filter -> aggregate -> limit -> column pruning.
*
* @since 3.0.0
*/
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* push down LIMIT. Please note that the combination of LIMIT with other operations
* such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
*
* @since 3.3.0
*/
@Evolving
public interface SupportsPushDownLimit extends ScanBuilder {

/**
* Pushes down LIMIT to the data source.
*/
boolean pushLimit(int limit);
}
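For connector authors, a minimal sketch of a builder that mixes in the new interface; `MyScanBuilder`, `MyScan`, and `maxRows` are hypothetical names, not part of this change:

```scala
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit}
import org.apache.spark.sql.types.StructType

// Record the pushed LIMIT and honor it when producing the Scan.
class MyScanBuilder extends SupportsPushDownLimit {
  private var maxRows: Option[Int] = None

  // Return true only if the source can actually stop after `limit` rows;
  // Spark still applies LIMIT on top of whatever the source returns.
  override def pushLimit(limit: Int): Boolean = {
    maxRows = Some(limit)
    true
  }

  override def build(): Scan = new MyScan(maxRows)
}

// Sketch of a Scan that would thread the limit into its readers (reader logic omitted).
class MyScan(maxRows: Option[Int]) extends Scan {
  override def readSchema(): StructType = new StructType().add("id", "long")
}
```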
@@ -104,6 +104,7 @@ case class RowDataSourceScanExec(
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Option[Aggregation],
limit: Option[Int],
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -153,7 +154,8 @@
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq),
"PushedAggregates" -> aggString,
"PushedGroupby" -> groupByString)
"PushedGroupby" -> groupByString) ++
limit.map(value => "PushedLimit" -> s"LIMIT $value")
}

// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
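When a limit reaches this scan node, the new metadata entry should show up in `explain()` output. A rough illustration (the exact plan text varies by query and version; `df` is assumed to be a JDBC-backed DataFrame with `pushDownLimit` enabled):

```scala
df.limit(3).explain()
// Expected to include something like:
//   ... PushedFilters: [], PushedAggregates: [], PushedLimit: LIMIT 3, ReadSchema: ...
```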
@@ -336,6 +336,7 @@ object DataSourceStrategy
Set.empty,
Set.empty,
None,
None,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,6 +411,7 @@
pushedFilters.toSet,
handledFilters,
None,
None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -433,6 +435,7 @@
pushedFilters.toSet,
handledFilters,
None,
None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -191,6 +191,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean

// An option to allow/disallow pushing down LIMIT into JDBC data source
val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean

// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -266,6 +269,7 @@ object JDBCOptions {
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
@@ -179,6 +179,8 @@ object JDBCRDD extends Logging {
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns or aggregate columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
* @param limit - The pushed down LIMIT. A value of 0 means there is no LIMIT, or the LIMIT
* is not pushed down.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
@@ -190,7 +192,8 @@
parts: Array[Partition],
options: JDBCOptions,
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
groupByColumns: Option[Array[String]] = None,
limit: Int = 0): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -208,7 +211,8 @@
parts,
url,
options,
groupByColumns)
groupByColumns,
limit)
}
}

@@ -226,7 +230,8 @@
partitions: Array[Partition],
url: String,
options: JDBCOptions,
groupByColumns: Option[Array[String]])
groupByColumns: Option[Array[String]],
limit: Int)
extends RDD[InternalRow](sc, Nil) {

/**
@@ -349,8 +354,10 @@

val myWhereClause = getWhereClause(part)

val myLimitClause: String = dialect.getLimitClause(limit)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
s" $getGroupByClause"
s" $getGroupByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
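Putting the pieces of the query assembly together: the statement text concatenates the column list, WHERE clause, GROUP BY clause, and now the dialect's LIMIT clause. A rough illustration with made-up values (the double space before LIMIT comes from the empty GROUP BY clause):

```scala
// Hypothetical values; the real ones come from the pushed columns, filters,
// aggregates, and the dialect's getLimitClause.
val columnList       = "\"NAME\",\"SALARY\""
val tableOrQuery     = "employee"
val myWhereClause    = "WHERE (\"SALARY\" IS NOT NULL)"
val getGroupByClause = ""        // empty: no aggregate pushed in this example
val myLimitClause    = "LIMIT 3" // assumed result of dialect.getLimitClause(3)

val sqlText = s"SELECT $columnList FROM $tableOrQuery $myWhereClause" +
  s" $getGroupByClause $myLimitClause"
// SELECT "NAME","SALARY" FROM employee WHERE ("SALARY" IS NOT NULL)  LIMIT 3
```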
@@ -298,7 +298,8 @@ private[sql] case class JDBCRelation(
requiredColumns: Array[String],
finalSchema: StructType,
filters: Array[Filter],
groupByColumns: Option[Array[String]]): RDD[Row] = {
groupByColumns: Option[Array[String]],
limit: Int): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -308,7 +309,8 @@
parts,
jdbcOptions,
Some(finalSchema),
groupByColumns).asInstanceOf[RDD[Row]]
groupByColumns,
limit).asInstanceOf[RDD[Row]]
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
@@ -93,20 +93,22 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat

override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters,
DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, limit), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
scan.readSchema(), v1Relation.schema)
}
val rdd = v1Relation.buildScan()
val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd)

val dsScan = RowDataSourceScanExec(
output,
output.toStructType,
Set.empty,
pushed.toSet,
aggregate,
limit,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
@@ -138,6 +138,17 @@ object PushDownUtils extends PredicateHelper {
}
}

/**
* Pushes down LIMIT to the data source Scan.
*/
def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = {
scanBuilder match {
case s: SupportsPushDownLimit =>
s.pushLimit(limit)
case _ => false
}
}

/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
@@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
@@ -36,7 +36,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._

def apply(plan: LogicalPlan): LogicalPlan = {
applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
applyColumnPruning(applyLimit(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))))
}

private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -225,6 +225,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
withProjection
}

def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
if (limitPushed) {
sHolder.setLimit(Some(limitValue))
}
globalLimit
case _ => globalLimit
}
}

private def getWrappedScan(
scan: Scan,
sHolder: ScanBuilderHolder,
@@ -236,7 +249,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
V1ScanWrapper(v1, pushedFilters, aggregation)
V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedLimit)
case _ => scan
}
}
@@ -245,13 +258,18 @@
case class ScanBuilderHolder(
output: Seq[AttributeReference],
relation: DataSourceV2Relation,
builder: ScanBuilder) extends LeafNode
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None
private[sql] def setLimit(limit: Option[Int]): Unit = pushedLimit = limit
}


// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
// the physical v1 scan node.
case class V1ScanWrapper(
v1Scan: V1Scan,
handledFilters: Seq[sources.Filter],
pushedAggregate: Option[Aggregation]) extends Scan {
pushedAggregate: Option[Aggregation],
pushedLimit: Option[Int]) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}
@@ -28,7 +28,8 @@ case class JDBCScan(
prunedSchema: StructType,
pushedFilters: Array[Filter],
pushedAggregateColumn: Array[String] = Array(),
groupByColumns: Option[Array[String]]) extends V1Scan {
groupByColumns: Option[Array[String]],
pushedLimit: Int) extends V1Scan {

override def readSchema(): StructType = prunedSchema

@@ -43,7 +44,7 @@
} else {
pushedAggregateColumn
}
relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, pushedLimit)
}
}.asInstanceOf[T]
}
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
import org.apache.spark.sql.jdbc.JdbcDialects
@@ -36,6 +36,7 @@ case class JDBCScanBuilder(
with SupportsPushDownFilters
with SupportsPushDownRequiredColumns
with SupportsPushDownAggregates
with SupportsPushDownLimit
with Logging {

private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@@ -44,6 +45,16 @@

private var finalSchema = schema

private var pushedLimit = 0

override def pushLimit(limit: Int): Boolean = {
if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) {
pushedLimit = limit
return true
}
false
}

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (jdbcOptions.pushDownPredicate) {
val dialect = JdbcDialects.get(jdbcOptions.url)
@@ -123,6 +134,6 @@
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
// be used in sql string.
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
pushedAggregateList, pushedGroupByCols)
pushedAggregateList, pushedGroupByCols, pushedLimit)
}
}
@@ -57,4 +57,8 @@ private object DerbyDialect extends JdbcDialect {
override def getTableCommentQuery(table: String, comment: String): String = {
throw QueryExecutionErrors.commentOnTableUnsupportedError()
}

// ToDo: use fetch first n rows only for limit, e.g.
// select * from employee fetch first 10 rows only;
override def supportsLimit(): Boolean = false
}
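`getLimitClause` and `supportsLimit` are referenced above, but their base definitions are not part of this diff; the sketch below is an assumption about what the `JdbcDialect` defaults might look like, not the actual implementation (the class name `SketchJdbcDialect` is made up):

```scala
// Assumed shape of the dialect hooks used by JDBCRDD and JDBCScanBuilder above.
abstract class SketchJdbcDialect {
  // Most databases accept a trailing "LIMIT n"; a limit of 0 means nothing was pushed.
  def getLimitClause(limit: Int): String =
    if (limit > 0) s"LIMIT $limit" else ""

  // Dialects without LIMIT syntax (e.g. Derby, which uses FETCH FIRST n ROWS ONLY)
  // return false so the clause stays out of the generated SQL.
  def supportsLimit(): Boolean = true
}
```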