Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ trait LogicalPlanVisitor[T] {
case p: LocalLimit => visitLocalLimit(p)
case p: Pivot => visitPivot(p)
case p: Project => visitProject(p)
case p: Range => visitRange(p)
case p: Repartition => visitRepartition(p)
case p: RepartitionByExpression => visitRepartitionByExpr(p)
case p: ResolvedHint => visitHint(p)
case p: Sample => visitSample(p)
case p: ScriptTransformation => visitScriptTransform(p)
case p: Union => visitUnion(p)
case p: Window => visitWindow(p)
case p: LogicalPlan => default(p)
}

Expand Down Expand Up @@ -73,8 +73,6 @@ trait LogicalPlanVisitor[T] {

def visitProject(p: Project): T

def visitRange(p: Range): T

def visitRepartition(p: Repartition): T

def visitRepartitionByExpr(p: RepartitionByExpression): T
Expand All @@ -84,4 +82,6 @@ trait LogicalPlanVisitor[T] {
def visitScriptTransform(p: ScriptTransformation): T

def visitUnion(p: Union): T

def visitWindow(p: Window): T
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
ProjectEstimation.estimate(p).getOrElse(fallback(p))
}

override def visitRange(p: logical.Range): Statistics = {
val sizeInBytes = LongType.defaultSize * p.numElements
Statistics(sizeInBytes = sizeInBytes)
}

override def visitRepartition(p: Repartition): Statistics = fallback(p)

override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p)
Expand All @@ -79,4 +74,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p)

override def visitUnion(p: Union): Statistics = fallback(p)

override def visitWindow(p: Window): Statistics = fallback(p)
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {

override def visitProject(p: Project): Statistics = visitUnaryNode(p)

override def visitRange(p: logical.Range): Statistics = {
p.computeStats()
}

override def visitRepartition(p: Repartition): Statistics = default(p)

override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p)
Expand All @@ -160,4 +156,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
override def visitUnion(p: Union): Statistics = {
Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum)
}

override def visitWindow(p: Window): Statistics = visitUnaryNode(p)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -54,6 +56,24 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
)
}

test("range") {
val range = Range(1, 5, 1, None)
val rangeStats = Statistics(sizeInBytes = 4 * 8)
checkStats(
range,
expectedStatsCboOn = rangeStats,
expectedStatsCboOff = rangeStats)
}

test("windows") {
val windows = plan.window(Seq(min(attribute).as('sum_attr)), Seq(attribute), Nil)
val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8))
checkStats(
windows,
expectedStatsCboOn = windowsStats,
expectedStatsCboOff = windowsStats)
}

test("limit estimation: limit < child's rowCount") {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
Expand Down