diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index d6e99312bb66e..16c377af3ac6a 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -272,6 +272,7 @@ Below is a list of all the keywords in Spark SQL. |NULL|reserved|non-reserved|reserved| |NULLS|non-reserved|non-reserved|non-reserved| |OF|non-reserved|non-reserved|reserved| +|OFFSET|reserved|non-reserved|reserved| |ON|reserved|strict-non-reserved|reserved| |ONLY|reserved|non-reserved|reserved| |OPTION|non-reserved|non-reserved|non-reserved| diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ad0de528708a4..2c097a527bffa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -460,6 +460,7 @@ queryOrganization (SORT BY sort+=sortItem (',' sort+=sortItem)*)? windowClause? (LIMIT (ALL | limit=expression))? + (OFFSET offset=expression)? ; multiInsertQueryBody @@ -1358,6 +1359,7 @@ nonReserved | NULL | NULLS | OF + | OFFSET | ONLY | OPTION | OPTIONS @@ -1611,6 +1613,7 @@ NOT: 'NOT' | '!'; NULL: 'NULL'; NULLS: 'NULLS'; OF: 'OF'; +OFFSET: 'OFFSET'; ON: 'ON'; ONLY: 'ONLY'; OPTION: 'OPTION'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 351be32ee438e..09248952f2630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -333,10 +333,30 @@ trait CheckAnalysis extends PredicateHelper { case GlobalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr) - case LocalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr) + case LocalLimit(limitExpr, child) => + checkLimitLikeClause("limit", limitExpr) + child match { + case Offset(offsetExpr, _) => + val limit = limitExpr.eval().asInstanceOf[Int] + val offset = offsetExpr.eval().asInstanceOf[Int] + if (Int.MaxValue - limit < offset) { + failAnalysis( + s"""The sum of limit and offset must not be greater than Int.MaxValue, + | but found limit = $limit, offset = $offset.""".stripMargin) + } + case _ => + } + + case Offset(offsetExpr, _) => checkLimitLikeClause("offset", offsetExpr) case Tail(limitExpr, _) => checkLimitLikeClause("tail", limitExpr) + case o if !o.isInstanceOf[GlobalLimit] && !o.isInstanceOf[LocalLimit] + && o.children.exists(_.isInstanceOf[Offset]) => + failAnalysis( + s"""Only the OFFSET clause is allowed in the LIMIT clause, but the OFFSET + | clause found in: ${o.nodeName}.""".stripMargin) + case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) def ordinalNumber(i: Int): String = i match { @@ -661,6 +681,7 @@ trait CheckAnalysis extends PredicateHelper { } } checkCollectedMetrics(plan) + checkOutermostOffset(plan) extendedCheckRules.foreach(_(plan)) plan.foreachUp { case o if !o.resolved => @@ -786,6 +807,20 @@ trait CheckAnalysis extends PredicateHelper { check(plan) } + /** + * Validate that the root node of a query or subquery is not [[Offset]], i.e. an OFFSET clause is only allowed together with a LIMIT clause. 
+ */ + private def checkOutermostOffset(plan: LogicalPlan): Unit = { + plan match { + case Offset(offsetExpr, _) => + checkLimitLikeClause("offset", offsetExpr) + failAnalysis( + s"""Only the OFFSET clause is allowed in the LIMIT clause, but the OFFSET + | clause is found to be the outermost node.""".stripMargin) + case _ => + } + } + /** * Validates to make sure the outer references appearing inside the subquery * are allowed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 46f178f3a9ce2..15858912b579a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -366,6 +366,8 @@ object UnsupportedOperationChecker extends Logging { throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + "output mode") + case Offset(_, _) => throwError("Offset is not supported on streaming DataFrames/Datasets") + case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index b61c4b8d065f2..0e1dda7788d51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -350,6 +350,8 @@ package object dsl { def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan) + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6033c01a60f47..0de799e45f8c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,6 +85,7 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseWindow, CombineFilters, CombineLimits, + RewriteOffsets, CombineUnions, // Constant folding and strength reduction TransposeWindow, @@ -1449,6 +1450,19 @@ object CombineLimits extends Rule[LogicalPlan] { } } +/** + * Rewrite a [[GlobalLimit]] over an [[Offset]] into a single [[GlobalLimitAndOffset]], and push a + * [[LocalLimit]] below its child [[Offset]], enlarging the local limit by the offset. + */ +object RewriteOffsets extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case GlobalLimit(le, Offset(oe, grandChild)) => + GlobalLimitAndOffset(le, oe, grandChild) + case LocalLimit(le, Offset(oe, grandChild)) => + Offset(oe, LocalLimit(Add(le, oe), grandChild)) + } +} + /** * Check if there any cartesian products between joins of any type in the optimized plan tree. * Throw an error if a cartesian product is found without an explicit cross join specified. 
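To make the effect of the new RewriteOffsets rule concrete, here is a rough sketch using the catalyst test DSL; the relation `t` and the literal values are illustrative only and are not part of this patch:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // A hypothetical one-column relation, as commonly used in catalyst tests.
    val t = LocalRelation('a.int)
    // Parsing `... LIMIT 10 OFFSET 5` places the limit on top of the offset:
    val parsed = t.offset(5).limit(10)
    // i.e. GlobalLimit(10, LocalLimit(10, Offset(5, t)))
    // One pass of RewriteOffsets pushes the local limit below the offset, enlarging it:
    //   GlobalLimit(10, Offset(5, LocalLimit(10 + 5, t)))
    // A later pass of the fixed-point batch collapses the global limit and the offset:
    //   GlobalLimitAndOffset(10, 5, LocalLimit(15, t))

The analyzer check added to CheckAnalysis above (the sum of limit and offset must not exceed Int.MaxValue) is what keeps the `Add(le, oe)` produced by the LocalLimit case from overflowing.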
@@ -1554,7 +1568,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).copy()), isStreaming) case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) => - LocalRelation(output, data.take(limit), isStreaming) + LocalRelation(output, data.take(limit), isStreaming) case Filter(condition, LocalRelation(output, data, isStreaming)) if !hasUnevaluableExpr(condition) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 2627202c09c45..3f3befdf88dbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -105,6 +105,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit case _: Sort => empty(p) case _: GlobalLimit if !p.isStreaming => empty(p) case _: LocalLimit if !p.isStreaming => empty(p) + case _: Offset => empty(p) case _: Repartition => empty(p) case _: RepartitionByExpression => empty(p) // An aggregate with non-empty group expression will return one output row per group when the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 0e7a39c54050e..3543716e19568 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -726,6 +726,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: Sample => true case _: GlobalLimit => true case _: LocalLimit => true + case _: Offset => true case _: Generate => true case _: Distinct => true case _: AppendColumns => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f133235a2636e..4acc09c7815ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -561,10 +561,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // WINDOWS val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + // OFFSET + // - OFFSET 0 is the same as omitting the OFFSET clause + val withOffset = withWindow.optional(offset) { + Offset(typedVisit(offset), withWindow) + } + // LIMIT // - LIMIT ALL is the same as omitting the LIMIT clause - withWindow.optional(limit) { - Limit(typedVisit(limit), withWindow) + withOffset.optional(limit) { + Limit(typedVisit(limit), withOffset) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index 18baced8f3d61..125b6266905b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -30,6 +30,7 @@ trait LogicalPlanVisitor[T] { case p: Filter => visitFilter(p) case p: Generate => visitGenerate(p) case p: 
GlobalLimit => visitGlobalLimit(p) + case p: Offset => visitOffset(p) case p: Intersect => visitIntersect(p) case p: Join => visitJoin(p) case p: LocalLimit => visitLocalLimit(p) @@ -60,6 +61,8 @@ trait LogicalPlanVisitor[T] { def visitGlobalLimit(p: GlobalLimit): T + def visitOffset(p: Offset): T + def visitIntersect(p: Intersect): T def visitJoin(p: Join): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 223ef652d2f80..dcd60401ceac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -746,6 +746,21 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A logical offset, which removes a specified number of rows from the beginning of the + * output of the child logical plan. + */ +case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + import scala.math.max + offsetExpr match { + case IntegerLiteral(offset) => child.maxRows.map { x => max(x - offset, 0) } + case _ => None + } + } +} + /** * A constructor for creating a pivot, which will later be converted to a [[Project]] * or an [[Aggregate]] during the query analysis. @@ -843,6 +858,23 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr } } +/** + * A global (coordinated) limit with offset. This operator skips the first `offsetExpr` rows of its + * child and then emits at most `limitExpr` rows in total. 
+ */ +case class GlobalLimitAndOffset( + limitExpr: Expression, + offsetExpr: Expression, + child: LogicalPlan) extends OrderPreservingUnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } +} + /** * This is similar with [[Limit]] except: * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index b8c652dc8f12e..8ccdb1d576d4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -47,6 +47,8 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p) + override def visitOffset(p: Offset): Statistics = fallback(p) + override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index da36db7ae1f5f..0c6e3a2a13bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -89,6 +89,15 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { rowCount = Some(rowCount)) } + override def visitOffset(p: Offset): Statistics = { + val offset = p.offsetExpr.eval().asInstanceOf[Int] + val childStats = p.child.stats + val rowCount: BigInt = childStats.rowCount.map(_.-(offset).max(0)).getOrElse(0) + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount)) + } + override def visitIntersect(p: Intersect): Statistics = { val leftSize = p.left.stats.sizeInBytes val rightSize = p.right.stats.sizeInBytes diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d3a14e511cdc2..70d50814803a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -494,6 +494,38 @@ class AnalysisErrorSuite extends AnalysisTest { "The limit expression must be equal to or greater than 0, but got -1" :: Nil ) + errorTest( + "an evaluated offset class must not be null", + testRelation.offset(Literal(null, IntegerType)), + "The evaluated offset expression must not be null, but got " :: Nil + ) + + errorTest( + "num_rows in offset clause must be equal to or greater than 0", + listRelation.offset(-1), + "The offset expression must be equal to or greater than 0, but got -1" :: Nil + ) + + errorTest( + "OFFSET clause is outermost node", + testRelation.offset(Literal(10, IntegerType)), + "Only the OFFSET clause 
is allowed in the LIMIT clause, but the OFFSET" :: + "clause is found to be the outermost node." :: Nil + ) + + errorTest( + "OFFSET clause in other node", + testRelation2.offset(Literal(10, IntegerType)).where('b > 1), + "Only the OFFSET clause is allowed in the LIMIT clause, but the OFFSET" :: + "clause found in: Filter." :: Nil + ) + + errorTest( + "the sum of num_rows in limit clause and num_rows in offset clause less than Int.MaxValue", + testRelation.offset(Literal(2000000000, IntegerType)).limit(Literal(1000000000, IntegerType)), + "The sum of limit and offset must not be greater than Int.MaxValue" :: Nil + ) + errorTest( "more than one generators in SELECT", listRelation.select(Explode($"list"), Explode($"list")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index dfe790dca54d8..fa9b6193f5a85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -60,6 +60,22 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = windowsStats) } + test("offset estimation: offset < child's rowCount") { + val offset = Offset(Literal(2), plan) + checkStats(offset, Statistics(sizeInBytes = 96, rowCount = Some(8))) + } + + test("offset estimation: offset > child's rowCount") { + val offset = Offset(Literal(20), plan) + checkStats(offset, Statistics(sizeInBytes = 1, rowCount = Some(0))) + } + + test("offset estimation: offset = 0") { + val offset = Offset(Literal(0), plan) + // Offset is equal to zero, so Offset's stats is equal to its child's stats. 
+ checkStats(offset, plan.stats.copy(attributeStats = AttributeMap(Nil))) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ba3d83714c302..5d0d45026043e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -82,22 +82,48 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ReturnAnswer(rootPlan) => rootPlan match { case Limit(IntegerLiteral(limit), Sort(order, true, child)) if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + TakeOrderedAndProjectExec(limit, 0, order, child.output, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + TakeOrderedAndProjectExec(limit, 0, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - CollectLimitExec(limit, planLater(child)) :: Nil + CollectLimitExec(limit, 0, planLater(child)) :: Nil + case GlobalLimitAndOffset( + IntegerLiteral(limit), + IntegerLiteral(offset), + Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, offset, order, child.output, planLater(child)) :: Nil + case GlobalLimitAndOffset( + IntegerLiteral(limit), + IntegerLiteral(offset), + Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, offset, order, projectList, planLater(child)) :: Nil + case GlobalLimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) => + CollectLimitExec(limit, offset, planLater(child)) :: Nil case Tail(IntegerLiteral(limit), child) => CollectTailExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case Limit(IntegerLiteral(limit), Sort(order, true, child)) if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + TakeOrderedAndProjectExec(limit, 0, order, child.output, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + TakeOrderedAndProjectExec(limit, 0, order, projectList, planLater(child)) :: Nil + case GlobalLimitAndOffset( + IntegerLiteral(limit), + IntegerLiteral(offset), + Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, offset, order, child.output, planLater(child)) :: Nil + case GlobalLimitAndOffset( + IntegerLiteral(limit), + IntegerLiteral(offset), + Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, offset, order, projectList, planLater(child)) :: Nil case _ => Nil } } @@ -691,6 +717,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.LocalLimitExec(limit, planLater(child)) :: Nil case logical.GlobalLimit(IntegerLiteral(limit), child) => execution.GlobalLimitExec(limit, 
planLater(child)) :: Nil + case logical.GlobalLimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) => + execution.GlobalLimitAndOffsetExec(limit, offset, planLater(child)) :: Nil case union: logical.Union => execution.UnionExec(union.children.map(planLater)) :: Nil case g @ logical.Generate(generator, _, outer, _, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ddbd0a343ffcf..aa7ce64ab9354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -22,6 +22,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.plans.logical.Limit import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -41,10 +42,12 @@ trait LimitExec extends UnaryExecNode { * This operator will be used when a logical `Limit` operation is the final operator in an * logical plan, which happens when the user is collecting results back to the driver. */ -case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec { +case class CollectLimitExec(limit: Int, offset: Int, child: SparkPlan) extends LimitExec { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + override def executeCollect(): Array[InternalRow] = { + child.executeTake(limit + offset).drop(offset) + } private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) @@ -52,7 +55,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec { SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { - val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) + val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit + offset)) val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( locallyLimited, @@ -61,7 +64,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec { serializer, writeMetrics), readMetrics) - shuffled.mapPartitionsInternal(_.take(limit)) + shuffled.mapPartitionsInternal(_.drop(offset).take(limit)) } } @@ -115,7 +118,7 @@ trait BaseLimitExec extends LimitExec with CodegenSupport { // to the parent operator. 
override def usedInputs: AttributeSet = AttributeSet.empty - private lazy val countTerm = BaseLimitExec.newLimitCountTerm() + protected lazy val countTerm = BaseLimitExec.newLimitCountTerm() override lazy val limitNotReachedChecks: Seq[String] = { s"$countTerm < $limit" +: super.limitNotReachedChecks @@ -161,6 +164,45 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { override def outputOrdering: Seq[SortOrder] = child.outputOrdering } +/** + * Skip the first `offset` elements then take the first `limit` of the following elements in + * the child's single output partition. + */ +case class GlobalLimitAndOffsetExec( + limit: Int, + offset: Int, + child: SparkPlan) extends BaseLimitExec { + + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doExecute(): RDD[InternalRow] = { + val rdd = child.execute().mapPartitions { iter => iter.take(limit + offset)} + rdd.zipWithIndex().filter(_._2 >= offset).map(_._1) + } + + private lazy val skipTerm = BaseLimitExec.newLimitCountTerm() + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + // The counter name is already obtained by the upstream operators via `limitNotReachedChecks`. + // Here we have to inline it to not change its name. This is fine as we won't have many limit + // operators in one query. + ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false) + ctx.addMutableState(CodeGenerator.JAVA_INT, skipTerm, forceInline = true, useFreshName = false) + s""" + | if ($skipTerm < $offset) { + | $skipTerm += 1; + | } else if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } + """.stripMargin + } +} + /** * Take the first limit elements as defined by the sortOrder, and do projection if needed. 
* This is logically equivalent to having a Limit operator after a [[SortExec]] operator, @@ -170,6 +212,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { */ case class TakeOrderedAndProjectExec( limit: Int, + offset: Int, sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { @@ -180,7 +223,7 @@ case class TakeOrderedAndProjectExec( override def executeCollect(): Array[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) - val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + val data = child.execute().map(_.copy()).takeOrdered(limit + offset)(ord).drop(offset) if (projectList != child.output) { val proj = UnsafeProjection.create(projectList, child.output) data.map(r => proj(r).copy()) @@ -201,7 +244,7 @@ case class TakeOrderedAndProjectExec( val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val localTopK: RDD[InternalRow] = { child.execute().map(_.copy()).mapPartitions { iter => - org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord) + org.apache.spark.util.collection.Utils.takeOrdered(iter, limit + offset)(ord) } } val shuffled = new ShuffledRowRDD( @@ -213,7 +256,8 @@ case class TakeOrderedAndProjectExec( writeMetrics), readMetrics) shuffled.mapPartitions { iter => - val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) + val topK = org.apache.spark.util.collection.Utils.takeOrdered( + iter.map(_.copy()), limit + offset)(ord).drop(offset) if (projectList != child.output) { val proj = UnsafeProjection.create(projectList, child.output) topK.map(r => proj(r)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/limit.sql index bc0b5d6dddc52..5d2b211175f26 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/limit.sql @@ -12,25 +12,24 @@ SELECT '' AS five, unique1, unique2, stringu1 SELECT '' AS two, unique1, unique2, stringu1 FROM onek WHERE unique1 > 60 AND unique1 < 63 ORDER BY unique1 LIMIT 5; --- [SPARK-28330] ANSI SQL: Top-level in --- SELECT '' AS three, unique1, unique2, stringu1 --- FROM onek WHERE unique1 > 100 --- ORDER BY unique1 LIMIT 3 OFFSET 20; --- SELECT '' AS zero, unique1, unique2, stringu1 --- FROM onek WHERE unique1 < 50 --- ORDER BY unique1 DESC LIMIT 8 OFFSET 99; --- SELECT '' AS eleven, unique1, unique2, stringu1 --- FROM onek WHERE unique1 < 50 --- ORDER BY unique1 DESC LIMIT 20 OFFSET 39; +SELECT '' AS three, unique1, unique2, stringu1 + FROM onek WHERE unique1 > 100 + ORDER BY unique1 LIMIT 3 OFFSET 20; +SELECT '' AS zero, unique1, unique2, stringu1 + FROM onek WHERE unique1 < 50 + ORDER BY unique1 DESC LIMIT 8 OFFSET 99; +SELECT '' AS eleven, unique1, unique2, stringu1 + FROM onek WHERE unique1 < 50 + ORDER BY unique1 DESC LIMIT 20 OFFSET 39; -- SELECT '' AS ten, unique1, unique2, stringu1 -- FROM onek -- ORDER BY unique1 OFFSET 990; -- SELECT '' AS five, unique1, unique2, stringu1 -- FROM onek -- ORDER BY unique1 OFFSET 990 LIMIT 5; --- SELECT '' AS five, unique1, unique2, stringu1 --- FROM onek --- ORDER BY unique1 LIMIT 5 OFFSET 900; +SELECT '' AS five, unique1, unique2, stringu1 + FROM onek + ORDER BY unique1 LIMIT 5 OFFSET 900; CREATE OR REPLACE TEMPORARY VIEW INT8_TBL AS SELECT * FROM (VALUES @@ -45,8 +44,7 @@ CREATE OR REPLACE TEMPORARY VIEW INT8_TBL AS SELECT * FROM -- constant, so to ensure 
executor is exercised, do this: -- [SPARK-29650] Discard a NULL constant in LIMIT select * from int8_tbl limit (case when random() < 0.5 then bigint(null) end); --- [SPARK-28330] ANSI SQL: Top-level in --- select * from int8_tbl offset (case when random() < 0.5 then bigint(null) end); +select * from int8_tbl offset (case when random() < 0.5 then bigint(null) end); -- Test assorted cases involving backwards fetch from a LIMIT plan node -- [SPARK-20965] Support PREPARE/EXECUTE/DECLARE/FETCH statements @@ -90,7 +88,6 @@ DROP VIEW INT8_TBL; -- Stress test for variable LIMIT in conjunction with bounded-heap sorting --- [SPARK-28330] ANSI SQL: Top-level in -- SELECT -- (SELECT n -- FROM (VALUES (1)) AS x, diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/limit.sql.out index 2c8bc31dbc6ca..79e33f54020e1 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/limit.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 12 -- !query @@ -38,6 +38,62 @@ struct 62 633 KCAAAA +-- !query +SELECT '' AS three, unique1, unique2, stringu1 + FROM onek WHERE unique1 > 100 + ORDER BY unique1 LIMIT 3 OFFSET 20 +-- !query schema +struct +-- !query output + 121 700 REAAAA + 122 519 SEAAAA + 123 777 TEAAAA + + +-- !query +SELECT '' AS zero, unique1, unique2, stringu1 + FROM onek WHERE unique1 < 50 + ORDER BY unique1 DESC LIMIT 8 OFFSET 99 +-- !query schema +struct +-- !query output + + + +-- !query +SELECT '' AS eleven, unique1, unique2, stringu1 + FROM onek WHERE unique1 < 50 + ORDER BY unique1 DESC LIMIT 20 OFFSET 39 +-- !query schema +struct +-- !query output + 10 520 KAAAAA + 9 49 JAAAAA + 8 653 IAAAAA + 7 647 HAAAAA + 6 978 GAAAAA + 5 541 FAAAAA + 4 833 EAAAAA + 3 431 DAAAAA + 2 326 CAAAAA + 1 214 BAAAAA + 0 998 AAAAAA + + +-- !query +SELECT '' AS five, unique1, unique2, stringu1 + FROM onek + ORDER BY unique1 LIMIT 5 OFFSET 900 +-- !query schema +struct +-- !query output + 900 913 QIAAAA + 901 931 RIAAAA + 902 702 SIAAAA + 903 641 TIAAAA + 904 793 UIAAAA + + -- !query CREATE OR REPLACE TEMPORARY VIEW INT8_TBL AS SELECT * FROM (VALUES @@ -62,6 +118,15 @@ org.apache.spark.sql.AnalysisException The limit expression must evaluate to a constant value, but got CASE WHEN (`_nondeterministic` < CAST(0.5BD AS DOUBLE)) THEN CAST(NULL AS BIGINT) END; +-- !query +select * from int8_tbl offset (case when random() < 0.5 then bigint(null) end) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +The offset expression must evaluate to a constant value, but got CASE WHEN (`_nondeterministic` < CAST(0.5BD AS DOUBLE)) THEN CAST(NULL AS BIGINT) END; + + -- !query DROP VIEW INT8_TBL -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 376d330ebeb70..ef5695e581fc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { checkThatPlansAgree( generateRandomInputData(), input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + 
noOpFilter(TakeOrderedAndProjectExec(limit, 0, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, @@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { generateRandomInputData(), input => noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + TakeOrderedAndProjectExec(limit, 0, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit,
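End to end, the new clause behaves roughly as follows for a query collected back to the driver; the table `t`, the column `id` and the session `spark` below are illustrative only, not taken from this patch:

    // Skip the first 5 rows, then return at most the next 10.
    spark.sql("SELECT id FROM t LIMIT 10 OFFSET 5").collect()

The parser builds Limit(10, Offset(5, ...)); RewriteOffsets rewrites this into GlobalLimitAndOffset(10, 5, ...) over a local limit enlarged to 15; and the new SpecialLimits case plans the collect as CollectLimitExec(limit = 10, offset = 5, ...), whose executeCollect takes limit + offset = 15 rows from the child and then drops the first 5. The new TakeOrderedAndProjectExec(limit, offset, ...) cases handle the ordered variants the same way: keep the first limit + offset ordered rows, then drop the first offset of them.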