@@ -777,8 +777,8 @@ object DataSourceStrategy
}

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
- def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
-   case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+ def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
+   case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _) =>
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
@@ -787,11 +787,11 @@ object DataSourceStrategy
case NullsFirst => NullOrdering.NULLS_FIRST
case NullsLast => NullOrdering.NULLS_LAST
}
- Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+ Some(SortValue(expr, directionV2, nullOrderingV2))
case _ => None
}

- sortOrders.flatMap(translateOortOrder)
+ sortOrders.flatMap(translateSortOrder)
}

/**
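Note: for context, a minimal sketch (not part of this patch) of what the rewritten translateSortOrder emits. The PushableExpression extractor lets any translatable Catalyst expression become the sort key, not just a top-level column as before; the column name below is hypothetical.

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}

// A plain column as sort key; a CASE WHEN expression that the extractor
// can translate would appear here as its V2 expression instead.
val v2Order = SortValue(
  FieldReference("SALARY"),    // expr bound by PushableExpression
  SortDirection.ASCENDING,     // mapped from Catalyst's Ascending
  NullOrdering.NULLS_FIRST)    // mapped from Catalyst's NullsFirst
// v2Order.describe() renders roughly as "SALARY ASC NULLS FIRST"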
@@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
@@ -123,7 +122,7 @@ object JDBCRDD extends Logging {
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0,
- sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
+ sortOrders: Array[String] = Array.empty[String]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -166,7 +165,7 @@ private[jdbc] class JDBCRDD(
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int,
- sortOrders: Array[SortOrder])
+ sortOrders: Array[String])
extends RDD[InternalRow](sc, Nil) {

/**
@@ -216,7 +215,7 @@

private def getOrderByClause: String = {
if (sortOrders.nonEmpty) {
- s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+ s" ORDER BY ${sortOrders.mkString(", ")}"
} else {
""
}
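Note: because the sort keys now reach JDBCRDD as dialect-compiled SQL fragments (Array[String]), building the ORDER BY clause is plain string assembly. A self-contained sketch with hypothetical values:

val sortOrders = Array("SALARY ASC NULLS FIRST", "BONUS DESC NULLS LAST")
val orderByClause =
  if (sortOrders.nonEmpty) s" ORDER BY ${sortOrders.mkString(", ")}" else ""
// orderByClause == " ORDER BY SALARY ASC NULLS FIRST, BONUS DESC NULLS LAST"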
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
- import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -305,7 +304,7 @@ private[sql] case class JDBCRelation(
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
limit: Int,
- sortOrders: Array[SortOrder]): RDD[Row] = {
+ sortOrders: Array[String]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -372,7 +372,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
order, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
- val orders = DataSourceStrategy.translateSortOrders(newOrder)
+ val normalizedOrders = DataSourceStrategy.normalizeExprs(
+   newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+ val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
if (orders.length == order.length) {
val (isPushed, isPartiallyPushed) =
PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
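Note: the added normalizeExprs call resolves sort attributes against the relation's output, so the pushed-down ORDER BY uses the source's own column casing (SALARY rather than salary in the tests below). A self-contained sketch of the idea, using a hypothetical Attr stand-in for Catalyst attributes:

case class Attr(name: String)

// Match each sort attribute, case-insensitively, to the relation's output
// attribute so the generated SQL refers to the source's exact column name.
def normalize(order: Seq[Attr], output: Seq[Attr]): Seq[Attr] =
  order.map(a => output.find(_.name.equalsIgnoreCase(a.name)).getOrElse(a))

normalize(Seq(Attr("salary")), Seq(Attr("SALARY")))  // Seq(Attr("SALARY"))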
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
- import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
@@ -34,7 +33,7 @@ case class JDBCScan(
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
pushedLimit: Int,
- sortOrders: Array[SortOrder]) extends V1Scan {
+ sortOrders: Array[String]) extends V1Scan {

override def readSchema(): StructType = prunedSchema

@@ -53,7 +53,7 @@ case class JDBCScanBuilder(

private var pushedLimit = 0

- private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+ private var sortOrders: Array[String] = Array.empty[String]

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
if (jdbcOptions.pushDownPredicate) {
@@ -139,8 +139,14 @@

override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
if (jdbcOptions.pushDownLimit) {
+ val dialect = JdbcDialects.get(jdbcOptions.url)
+ val compiledOrders = orders.flatMap { order =>
+   dialect.compileExpression(order.expression())
+     .map(sortKey => s"$sortKey ${order.direction()} ${order.nullOrdering()}")
+ }
+ if (orders.length != compiledOrders.length) return false
pushedLimit = limit
- sortOrders = orders
+ sortOrders = compiledOrders
return true
}
false
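Note: compileExpression returns an Option, so a sort expression the dialect cannot compile yields None, and the length check abandons the whole top-N push-down; pushing a partial ORDER BY would silently change result order. A sketch of one compiled element, with a hypothetical sort key:

// What dialect.compileExpression(order.expression()) might return:
val sortKey: Option[String] =
  Some("CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END")
// order.direction() and order.nullOrdering() render as "ASC"/"DESC" and
// "NULLS FIRST"/"NULLS LAST", so one element of compiledOrders looks like:
val compiled: Option[String] = sortKey.map(k => s"$k ASC NULLS FIRST")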
@@ -222,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df1)
checkLimitRemoved(df1)
checkPushedInfo(df1,
- "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+ "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))

val df2 = spark.read
@@ -237,7 +237,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df2)
checkLimitRemoved(df2)
checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
- "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))

val df3 = spark.read
@@ -252,7 +252,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df3, false)
checkLimitRemoved(df3, false)
checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
- "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1, ")
checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df4 =
@@ -261,7 +261,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkSortRemoved(df4)
checkLimitRemoved(df4)
checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
- "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
+ "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1, ")
checkAnswer(df4, Seq(Row("david")))

val df5 = spark.read.table("h2.test.employee")
@@ -304,6 +304,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkLimitRemoved(df8, false)
checkPushedInfo(df8, "PushedFilters: [], ")
checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

+ val df9 = spark.read
+   .table("h2.test.employee")
+   .select($"DEPT", $"name", $"SALARY",
+     when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
+   .sort("key", "dept", "SALARY")
+   .limit(3)
+ checkSortRemoved(df9)
+ checkLimitRemoved(df9)
+ checkPushedInfo(df9, "PushedFilters: [], " +
+   "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+   "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+ checkAnswer(df9,
+   Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
+
+ val df10 = spark.read
+   .option("partitionColumn", "dept")
+   .option("lowerBound", "0")
+   .option("upperBound", "2")
+   .option("numPartitions", "2")
+   .table("h2.test.employee")
+   .select($"DEPT", $"name", $"SALARY",
+     when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
+   .orderBy($"key", $"dept", $"SALARY")
+   .limit(3)
+ checkSortRemoved(df10, false)
+ checkLimitRemoved(df10, false)
+ checkPushedInfo(df10, "PushedFilters: [], " +
+   "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+   "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+ checkAnswer(df10,
+   Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
}

test("simple scan with top N: order by with alias") {