diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 1b14884e75994..5cec7c634f57f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -777,8 +777,8 @@ object DataSourceStrategy
   }
 
   protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
-    def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
-      case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+    def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
+      case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _) =>
         val directionV2 = directionV1 match {
           case Ascending => SortDirection.ASCENDING
           case Descending => SortDirection.DESCENDING
@@ -787,11 +787,11 @@ object DataSourceStrategy
           case NullsFirst => NullOrdering.NULLS_FIRST
           case NullsLast => NullOrdering.NULLS_LAST
         }
-        Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+        Some(SortValue(expr, directionV2, nullOrderingV2))
       case _ => None
     }
 
-    sortOrders.flatMap(translateOortOrder)
+    sortOrders.flatMap(translateSortOrder)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index b30b460ac67db..13d6156aed16d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
@@ -123,7 +122,7 @@ object JDBCRDD extends Logging {
       groupByColumns: Option[Array[String]] = None,
       sample: Option[TableSampleInfo] = None,
       limit: Int = 0,
-      sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
+      sortOrders: Array[String] = Array.empty[String]): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -166,7 +165,7 @@ private[jdbc] class JDBCRDD(
     groupByColumns: Option[Array[String]],
     sample: Option[TableSampleInfo],
     limit: Int,
-    sortOrders: Array[SortOrder])
+    sortOrders: Array[String])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -216,7 +215,7 @@ private[jdbc] class JDBCRDD(
 
   private def getOrderByClause: String = {
     if (sortOrders.nonEmpty) {
-      s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+      s" ORDER BY ${sortOrders.mkString(", ")}"
     } else {
       ""
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 0f1a1b6dc667b..ea8410276072f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -305,7 +304,7 @@ private[sql] case class JDBCRelation(
       groupByColumns: Option[Array[String]],
       tableSample: Option[TableSampleInfo],
       limit: Int,
-      sortOrders: Array[SortOrder]): RDD[Row] = {
+      sortOrders: Array[String]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 6455e25089276..6835f4a763821 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -372,7 +372,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
             order, project, alwaysInline = true) =>
           val aliasMap = getAliasMap(project)
           val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
-          val orders = DataSourceStrategy.translateSortOrders(newOrder)
+          val normalizedOrders = DataSourceStrategy.normalizeExprs(
+            newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+          val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
           if (orders.length == order.length) {
             val (isPushed, isPartiallyPushed) =
               PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index f68f78d51fd96..5ca23e550aae2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
@@ -34,7 +33,7 @@ case class JDBCScan(
     groupByColumns: Option[Array[String]],
     tableSample: Option[TableSampleInfo],
     pushedLimit: Int,
-    sortOrders: Array[SortOrder]) extends V1Scan {
+    sortOrders: Array[String]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 0a1542a42956d..9ddb5ea675c73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -53,7 +53,7 @@ case class JDBCScanBuilder(
 
   private var pushedLimit = 0
 
-  private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+  private var sortOrders: Array[String] = Array.empty[String]
 
   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
     if (jdbcOptions.pushDownPredicate) {
@@ -139,8 +139,14 @@ case class JDBCScanBuilder(
 
   override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
     if (jdbcOptions.pushDownLimit) {
+      val dialect = JdbcDialects.get(jdbcOptions.url)
+      val compiledOrders = orders.flatMap { order =>
+        dialect.compileExpression(order.expression())
+          .map(sortKey => s"$sortKey ${order.direction()} ${order.nullOrdering()}")
+      }
+      if (orders.length != compiledOrders.length) return false
       pushedLimit = limit
-      sortOrders = orders
+      sortOrders = compiledOrders
       return true
     }
     false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 5cfa2f465a2be..173e00e4da9a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -222,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkSortRemoved(df1)
     checkLimitRemoved(df1)
     checkPushedInfo(df1,
-      "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+      "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
 
     val df2 = spark.read
@@ -237,7 +237,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkSortRemoved(df2)
     checkLimitRemoved(df2)
     checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
-      "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
     checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
 
     val df3 = spark.read
@@ -252,7 +252,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkSortRemoved(df3, false)
     checkLimitRemoved(df3, false)
     checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
-      "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY DESC NULLS LAST] LIMIT 1, ")
     checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df4 =
@@ -261,7 +261,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkSortRemoved(df4)
     checkLimitRemoved(df4)
     checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
-      "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
+      "PushedTopN: ORDER BY [SALARY ASC NULLS LAST] LIMIT 1, ")
     checkAnswer(df4, Seq(Row("david")))
 
     val df5 = spark.read.table("h2.test.employee")
@@ -304,6 +304,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkLimitRemoved(df8, false)
     checkPushedInfo(df8, "PushedFilters: [], ")
     checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+
+    val df9 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"name", $"SALARY",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
+      .sort("key", "dept", "SALARY")
+      .limit(3)
+    checkSortRemoved(df9)
+    checkLimitRemoved(df9)
+    checkPushedInfo(df9, "PushedFilters: [], " +
+      "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+      "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+    checkAnswer(df9,
+      Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
+
+    val df10 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"DEPT", $"name", $"SALARY",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
+      .orderBy($"key", $"dept", $"SALARY")
+      .limit(3)
+    checkSortRemoved(df10, false)
+    checkLimitRemoved(df10, false)
+    checkPushedInfo(df10, "PushedFilters: [], " +
+      "PushedTopN: ORDER BY [CASE WHEN (SALARY > 8000.00) AND " +
+      "(SALARY < 10000.00) THEN SALARY ELSE 0.00 END ASC NULL..., ")
+    checkAnswer(df10,
+      Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
   }
 
   test("simple scan with top N: order by with alias") {
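Note: the `pushTopN` change in `JDBCScanBuilder` above follows a compile-or-reject pattern: every V2 sort expression is compiled through the JDBC dialect, and the push-down is abandoned when any expression fails to compile, so the Sort stays in Spark. The sketch below illustrates that pattern in isolation; `SortSpec` and `compileExpression` are simplified stand-ins for illustration, not the real connector API.

```scala
// Standalone sketch of the compile-or-reject pattern used in pushTopN above.
// SortSpec and compileExpression are simplified stand-ins, not Spark's API.
object TopNPushDownSketch {

  // Rough analogue of a V2 sort order: an expression plus direction and null ordering.
  final case class SortSpec(expression: String, direction: String, nullOrdering: String)

  // Hypothetical dialect hook: returns None when the expression cannot be
  // rendered as SQL for the remote database.
  def compileExpression(expr: String): Option[String] =
    if (expr.forall(c => c.isLetterOrDigit || c == '_')) Some("\"" + expr + "\"") else None

  // Compile every sort key; succeed only if all of them compile, mirroring
  // the `orders.length != compiledOrders.length` check in JDBCScanBuilder.
  def compileOrders(orders: Seq[SortSpec]): Option[Seq[String]] = {
    val compiled = orders.flatMap { order =>
      compileExpression(order.expression)
        .map(key => s"$key ${order.direction} ${order.nullOrdering}")
    }
    if (compiled.length == orders.length) Some(compiled) else None
  }

  def main(args: Array[String]): Unit = {
    val orders = Seq(SortSpec("SALARY", "ASC", "NULLS FIRST"))
    // The ORDER BY fragment is assembled the same way JDBCRDD.getOrderByClause joins
    // its precompiled strings.
    compileOrders(orders) match {
      case Some(keys) => println(s" ORDER BY ${keys.mkString(", ")}")
      case None => println("sort not pushed; Spark keeps the Sort node")
    }
  }
}
```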