diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index 2652f6d72730c..e0748043c46e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -35,13 +35,13 @@ trait LogicalPlanVisitor[T] { case p: LocalLimit => visitLocalLimit(p) case p: Pivot => visitPivot(p) case p: Project => visitProject(p) - case p: Range => visitRange(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) case p: ResolvedHint => visitHint(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) + case p: Window => visitWindow(p) case p: LogicalPlan => default(p) } @@ -73,8 +73,6 @@ trait LogicalPlanVisitor[T] { def visitProject(p: Project): T - def visitRange(p: Range): T - def visitRepartition(p: Repartition): T def visitRepartitionByExpr(p: RepartitionByExpression): T @@ -84,4 +82,6 @@ trait LogicalPlanVisitor[T] { def visitScriptTransform(p: ScriptTransformation): T def visitUnion(p: Union): T + + def visitWindow(p: Window): T } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 93908b04fb643..4cff72d45a400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -65,11 +65,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { ProjectEstimation.estimate(p).getOrElse(fallback(p)) } - override def visitRange(p: logical.Range): Statistics = { - val sizeInBytes = LongType.defaultSize * p.numElements - Statistics(sizeInBytes = sizeInBytes) - } - override def visitRepartition(p: Repartition): Statistics = fallback(p) override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) @@ -79,4 +74,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p) override def visitUnion(p: Union): Statistics = fallback(p) + + override def visitWindow(p: Window): Statistics = fallback(p) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 559f12072e448..d701a956887a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -136,10 +136,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitProject(p: Project): Statistics = visitUnaryNode(p) - override def visitRange(p: logical.Range): Statistics = { - p.computeStats() - } - override def visitRepartition(p: Repartition): Statistics = default(p) override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p) @@ -160,4 +156,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitUnion(p: Union): Statistics = { Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum) } + + override def visitWindow(p: Window): Statistics = visitUnaryNode(p) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 913be6d1ff07f..7d532ff343178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.statsEstimation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,6 +56,24 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { ) } + test("range") { + val range = Range(1, 5, 1, None) + val rangeStats = Statistics(sizeInBytes = 4 * 8) + checkStats( + range, + expectedStatsCboOn = rangeStats, + expectedStatsCboOff = rangeStats) + } + + test("windows") { + val windows = plan.window(Seq(min(attribute).as('sum_attr)), Seq(attribute), Nil) + val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8)) + checkStats( + windows, + expectedStatsCboOn = windowsStats, + expectedStatsCboOff = windowsStats) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan)