@@ -20,7 +20,7 @@
import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down LIMIT. Please note that the combination of LIMIT with other operations
* such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
*
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.SortOrder;

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down top N (query with ORDER BY ... LIMIT n). Please note that the combination of top N
* with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
* is NOT pushed down.
*
* @since 3.3.0
*/
@Evolving
public interface SupportsPushDownTopN extends ScanBuilder {

/**
* Pushes down top N to the data source.
*/
boolean pushTopN(SortOrder[] orders, int limit);
}
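
For context, a minimal connector-side sketch (not part of this diff) of how a ScanBuilder could mix in the new SupportsPushDownTopN alongside the existing SupportsPushDownLimit touched in the first hunk. The class and field names are hypothetical, and building the actual Scan is left out.

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownLimit, SupportsPushDownTopN}

class ExampleScanBuilder extends ScanBuilder
    with SupportsPushDownLimit with SupportsPushDownTopN {

  private var pushedOrders: Array[SortOrder] = Array.empty
  private var pushedLimit: Option[Int] = None

  override def pushLimit(limit: Int): Boolean = {
    pushedLimit = Some(limit)
    true // returning false tells Spark the LIMIT was not accepted
  }

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // Accept the top N only if every sort key can be evaluated by this source;
    // this sketch simply accepts everything it is offered.
    pushedOrders = orders
    pushedLimit = Some(limit)
    true
  }

  // Building a Scan that honors pushedOrders/pushedLimit is source-specific and
  // out of scope for this sketch.
  override def build(): Scan = ???
}

With such a builder in place, the planner-side changes later in this diff (PushDownUtils.pushTopN and V2ScanRelationPushDown.pushDownLimit) decide whether pushTopN or pushLimit is called, based on the shape of the logical plan.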
@@ -142,13 +142,23 @@ case class RowDataSourceScanExec(
handledFilters
}

val topNOrLimitInfo =
if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
val pushedTopN =
s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" +
s" LIMIT ${pushedDownOperators.limit.get}"
Some("pushedTopN" -> pushedTopN)
} else {
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
}

Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq)) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
"PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
)
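
A throwaway sketch (not part of the diff) of the value that ends up under the pushedTopN key, assuming seqToString brackets its input and that a pushed sort value describes itself as something like SALARY DESC NULLS LAST:

// Stand-ins for sortValues.map(_.describe()) and the pushed limit.
val sortDescriptions = Seq("SALARY DESC NULLS LAST")
val limit = 3
val pushedTopN = s"ORDER BY ${sortDescriptions.mkString("[", ", ", "]")} LIMIT $limit"
// pushedTopN == "ORDER BY [SALARY DESC NULLS LAST] LIMIT 3", surfaced in EXPLAIN under pushedTopN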
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -336,7 +336,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,7 +410,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -433,7 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -726,6 +726,21 @@ object DataSourceStrategy
}
}

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
sortOrders.map {
Review comment (Member):
Hi, All. This broke Scala 2.13 compilation.

[error] /home/runner/work/spark/spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala:730:20: match may not be exhaustive.
[error] It would fail on the following input: SortOrder(_, _, _, _)
[error]     sortOrders.map {
[error]                    ^
[warn] 24 warnings found
[error] one error found
[error] (sql / Compile / compileIncremental) Compilation failed
[error] Total time: 267 s (04:27), completed Dec 15, 2021 5:57:25 AM

case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
Review comment (@dongjoon-hyun, Member, Dec 15, 2021):
In addition, this will cause scala.MatchError in Scala 2.12. We need a new test case whose input is not matched by this case, @beliefer.

val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
}
val nullOrderingV2 = nullOrderingV1 match {
case NullsFirst => NullOrdering.NULLS_FIRST
case NullsLast => NullOrdering.NULLS_LAST
}
SortValue(FieldReference(name), directionV2, nullOrderingV2)
}
}
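
To make the reviewer's concern above concrete, here is a hypothetical query whose sort key is an arithmetic expression rather than a bare column (the DataFrame, table and column names are made up, and the source is assumed to be a DSv2 relation that supports top N pushdown). Its catalyst SortOrder has an Add expression as its child, so PushableColumnWithoutNestedColumn does not match and the partial function above is undefined for it — the non-exhaustive match (Scala 2.13) / scala.MatchError (Scala 2.12) case pointed out in the comments.

import org.apache.spark.sql.functions.col

// `employees` is assumed to be a DataFrame backed by a DSv2 source with top N support.
val unmatchedTopN = employees
  .orderBy((col("bonus") + col("salary")).desc) // SortOrder child is Add(bonus, salary), not a column
  .limit(3)
// When pushDownLimit (later in this diff) calls translateSortOrders on this sort,
// the case above does not apply, hence the exhaustivity error / MatchError.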

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
@@ -151,12 +152,14 @@ object JDBCRDD extends Logging {
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns or aggregate columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
* @param sample - The pushed down tableSample.
* @param limit - The pushed down limit. If the value is 0, it means no limit or limit
* is not pushed down.
* @param sample - The pushed down tableSample.
* @param sortValues - The pushed down sort values, which cooperate with the limit to realize top N.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
// scalastyle:off argcount
def scanTable(
sc: SparkContext,
schema: StructType,
@@ -167,7 +170,8 @@
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0): RDD[InternalRow] = {
limit: Int = 0,
sortValues: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
Expand All @@ -187,8 +191,10 @@ object JDBCRDD extends Logging {
options,
groupByColumns,
sample,
limit)
limit,
sortValues)
}
// scalastyle:on argcount
}

/**
@@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD(
options: JDBCOptions,
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int)
limit: Int,
sortValues: Array[SortOrder])
extends RDD[InternalRow](sc, Nil) {

/**
@@ -255,6 +262,14 @@
}
}

private def getOrderByClause: String = {
if (sortValues.nonEmpty) {
s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}"
} else {
""
}
}
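
As a standalone illustration (not part of this diff) of how the new ORDER BY clause slots into the statement assembled in compute() further down, here is a plain Scala sketch with made-up column and table names, assuming a dialect that renders the pushed limit as LIMIT n:

object TopNSqlSketch extends App {
  // Hard-coded stand-ins for the pieces that compute() derives from the JDBC options,
  // the dialect and the pushed-down operators.
  val columnList = "\"NAME\",\"SALARY\""
  val tableOrQuery = "employee"
  val myTableSampleClause = ""                        // no TABLESAMPLE pushed
  val myWhereClause = "WHERE \"SALARY\" > 100"
  val getGroupByClause = ""                           // no aggregate pushed
  val getOrderByClause = " ORDER BY \"SALARY\" DESC"  // built from sortValues as above
  val myLimitClause = "LIMIT 3"                       // from the pushed limit

  val sqlText = s"SELECT $columnList FROM $tableOrQuery $myTableSampleClause" +
    s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
  // Prints roughly: SELECT "NAME","SALARY" FROM employee WHERE "SALARY" > 100 ORDER BY "SALARY" DESC LIMIT 3
  // (with extra blanks where the empty clauses sit)
  println(sqlText)
}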

/**
* Runs the SQL query against the JDBC driver.
*
@@ -339,7 +354,7 @@
val myLimitClause: String = dialect.getLimitClause(limit)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
s" $myWhereClause $getGroupByClause $myLimitClause"
s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
@@ -301,7 +302,8 @@ private[sql] case class JDBCRelation(
filters: Array[Filter],
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
limit: Int): RDD[Row] = {
limit: Int,
sortValues: Array[SortOrder]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -313,7 +315,8 @@
Some(finalSchema),
groupByColumns,
tableSample,
limit).asInstanceOf[RDD[Row]]
limit,
sortValues).asInstanceOf[RDD[Row]]
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
@@ -22,10 +22,10 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
@@ -157,6 +157,17 @@ object PushDownUtils extends PredicateHelper {
}
}

/**
* Pushes down top N to the data source Scan
*/
def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
scanBuilder match {
case s: SupportsPushDownTopN =>
s.pushTopN(order, limit)
case _ => false
}
}

/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation

/**
@@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
case class PushedDownOperators(
aggregation: Option[Aggregation],
sample: Option[TableSampleInfo],
limit: Option[Int])
limit: Option[Int],
sortValues: Seq[SortOrder]) {
assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}
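
A small illustration of the invariant the assert encodes: pushed sort values only make sense as part of a top N, i.e. together with a pushed limit. The snippet assumes it is compiled inside Spark's sql module (SortValue and FieldReference are internal classes, constructed here the same way as in translateSortOrders above), and the column name is made up.

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}

val bySalaryDesc = SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)

PushedDownOperators(None, None, Some(3), Seq(bySalaryDesc)) // top N: ORDER BY salary DESC LIMIT 3
PushedDownOperators(None, None, Some(3), Seq.empty)         // plain LIMIT 3, no sort pushed
// PushedDownOperators(None, None, None, Seq(bySalaryDesc)) // would fail the assert: sort without a limit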
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -246,17 +247,35 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
if (limitPushed) {
sHolder.pushedLimit = Some(limit)
}
operation
case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
if filter.isEmpty =>
val orders = DataSourceStrategy.translateSortOrders(order)
val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
if (topNPushed) {
sHolder.pushedLimit = Some(limit)
sHolder.sortValues = orders
operation
} else {
s
}
case p: Project =>
val newChild = pushDownLimit(p.child, limit)
p.withNewChildren(Seq(newChild))
case other => other
}

def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
if (limitPushed) {
sHolder.pushedLimit = Some(limitValue)
}
globalLimit
case _ => globalLimit
}
val newChild = pushDownLimit(child, limitValue)
val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
globalLimit.withNewChildren(Seq(newLocalLimit))
}
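
For an end-to-end feel of the plan shape applyLimit and pushDownLimit target, here is a hypothetical session using a DSv2 JDBC catalog. The catalog options, table and column names are assumptions modeled on the JDBC V2 test setup rather than part of this diff, and `spark` is assumed to be an active SparkSession.

// Register an H2-backed DSv2 catalog so that JDBC scans go through the V2 code path.
spark.conf.set("spark.sql.catalog.h2",
  "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb")  // assumed in-memory H2 database
spark.conf.set("spark.sql.catalog.h2.driver", "org.h2.Driver")

// ORDER BY ... LIMIT over a bare scan: the optimized plan is roughly
// GlobalLimit(3, LocalLimit(3, Sort(salary DESC, scan relation))), which pushDownLimit
// rewrites by offering the sort plus limit to the ScanBuilder via SupportsPushDownTopN.
val topN = spark.sql(
  "SELECT name, salary FROM h2.test.employee ORDER BY salary DESC LIMIT 3")
topN.explain()  // the scan node is expected to report the pushed top N (the pushedTopN entry above)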

private def getWrappedScan(
@@ -270,8 +289,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
val pushedDownOperators =
PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
val pushedDownOperators = PushedDownOperators(aggregation,
sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortValues)
V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
}
@@ -284,6 +303,8 @@ case class ScanBuilderHolder(
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None

var sortValues: Seq[SortOrder] = Seq.empty[SortOrder]

var pushedSample: Option[TableSampleInfo] = None
}

@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -31,7 +32,8 @@ case class JDBCScan(
pushedAggregateColumn: Array[String] = Array(),
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
pushedLimit: Int) extends V1Scan {
pushedLimit: Int,
sortValues: Array[SortOrder]) extends V1Scan {

override def readSchema(): StructType = prunedSchema

@@ -46,8 +48,8 @@
} else {
pushedAggregateColumn
}
relation.buildScan(
columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit)
relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample,
pushedLimit, sortValues)
}
}.asInstanceOf[T]
}