@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


@@ -436,7 +435,7 @@ case class CatalogRelation(
createTime = -1
))

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
Member: Can we remove import org.apache.spark.sql.internal.SQLConf, too?

Member: Yes, this file's import org.apache.spark.sql.internal.SQLConf needs to be removed.

Member: Oh, I overlooked that.

Contributor Author: Thanks! I checked the other interface file; there are several files with this name...

// For data source tables, we will create a `LogicalRelation` and won't call this method, for
// hive serde tables, we will always generate a statistics.
// TODO: unify the table stats generation.

@@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
// Do reordering if the number of items is appropriate and join conditions exist.
// We also need to check if costs of all items can be evaluated.
if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
items.forall(_.stats(conf).rowCount.isDefined)) {
items.forall(_.stats.rowCount.isDefined)) {
JoinReorderDP.search(conf, items, conditions, output)
} else {
plan
@@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
/** Get the cost of the root node of this plan tree. */
def rootCost(conf: SQLConf): Cost = {
if (itemIds.size > 1) {
val rootStats = plan.stats(conf)
val rootStats = plan.stats
Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
} else {
// If the plan is a leaf item, it has zero cost.
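
A rough standalone sketch of the guard and cost that the reordering rule reads from these statistics (simplified stand-in types, not Spark's actual classes):

// Stand-in for the per-item statistics used by the reordering guard above.
case class ItemStats(rowCount: Option[BigInt], sizeInBytes: BigInt)
case class Cost(card: BigInt, size: BigInt)

// Reordering is only attempted when there are enough items, join conditions
// exist, and every item can report a row count.
def shouldReorder(items: Seq[ItemStats], dpThreshold: Int, hasConditions: Boolean): Boolean =
  items.size > 2 && items.size <= dpThreshold && hasConditions &&
    items.forall(_.rowCount.isDefined)

// Cost of an intermediate join result: its cardinality and byte size.
def rootCost(stats: ItemStats): Cost = Cost(stats.rowCount.get, stats.sizeInBytes)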

@@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
case FullOuter =>
(left.maxRows, right.maxRows) match {
case (None, None) =>
if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
join.copy(left = maybePushLimit(exp, left))
} else {
join.copy(right = maybePushLimit(exp, right))

@@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
// Find if the input plans are eligible for star join detection.
// An eligible plan is a base table access with valid statistics.
val foundEligibleJoin = input.forall {
case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
case _ => false
}

@@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
val leafCol = findLeafNodeCol(column, plan)
leafCol match {
case Some(col) if t.outputSet.contains(col) =>
val stats = t.stats(conf)
val stats = t.stats
stats.rowCount match {
case Some(rowCount) if rowCount >= 0 =>
if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
val leafCol = findLeafNodeCol(column, plan)
leafCol match {
case Some(col) if t.outputSet.contains(col) =>
val stats = t.stats(conf)
val stats = t.stats
stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
case None => false
}
@@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
*/
private def getTableAccessCardinality(
input: LogicalPlan): Option[BigInt] = input match {
case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
Option(input.stats(conf).rowCount.get)
case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
if (conf.cboEnabled && input.stats.rowCount.isDefined) {
Option(input.stats.rowCount.get)
} else {
Option(t.stats(conf).rowCount.get)
Option(t.stats.rowCount.get)
}
case _ => None
}
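
A minimal sketch of the cardinality lookup this hunk rewrites, with hypothetical simplified types (the real method pattern-matches on PhysicalOperation over a LeafNode):

case class PlanStats(rowCount: Option[BigInt])

// Prefer the filtered plan's CBO row count when available; otherwise fall back
// to the base table's row count, and give up if the table has no statistics.
def tableAccessCardinality(
    cboEnabled: Boolean,
    filteredStats: PlanStats,
    tableStats: PlanStats): Option[BigInt] =
  tableStats.rowCount.flatMap { tableRows =>
    if (cboEnabled && filteredStats.rowCount.isDefined) filteredStats.rowCount
    else Some(tableRows)
  }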

@@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
@@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
}
}

override def computeStats(conf: SQLConf): Statistics =
override def computeStats: Statistics =
Member: Can we remove import org.apache.spark.sql.internal.SQLConf, too?

Contributor Author: I checked all modified files and removed unused imports. Thanks!

Statistics(sizeInBytes =
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)


@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


@@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
* first time. If the configuration changes, the cache can be invalidated by calling
* [[invalidateStatsCache()]].
*/
final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
statsCache = Some(computeStats(conf))
final def stats: Statistics = statsCache.getOrElse {
Member: ditto

statsCache = Some(computeStats)
statsCache.get
}

@@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
*
* [[LeafNode]]s must override this.
*/
protected def computeStats(conf: SQLConf): Statistics = {
protected def computeStats: Statistics = {
if (children.isEmpty) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
}

override def verboseStringWithSuffix: String = {
@@ -333,21 +332,21 @@ abstract class UnaryNode extends LogicalPlan {

override protected def validConstraints: Set[Expression] = child.constraints

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
// Assume there will be the same number of rows as child has.
var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
sizeInBytes = 1
}

// Don't propagate rowCount and attributeStats, since they are not estimated here.
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
}
}

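
To see the shape of the new parameterless API in isolation, here is a hedged standalone sketch (simplified stand-ins, not the real LogicalPlan hierarchy): stats memoizes computeStats exactly as in the hunk above, and callers that previously wrote plan.stats(conf) now write plan.stats.

case class Statistics(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

trait PlanLike {
  def children: Seq[PlanLike]

  // Filled in on first access; cleared elsewhere when the plan or configuration changes.
  private var statsCache: Option[Statistics] = None

  final def stats: Statistics = statsCache.getOrElse {
    statsCache = Some(computeStats)
    statsCache.get
  }

  // Default estimate: product of the children's sizes; leaf nodes must override this.
  protected def computeStats: Statistics = {
    if (children.isEmpty) {
      throw new UnsupportedOperationException("leaf nodes must implement statistics")
    }
    Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
  }
}

// A leaf with a fixed size, mirroring overrides such as LocalRelation's.
case class Leaf(size: BigInt) extends PlanLike {
  def children: Seq[PlanLike] = Nil
  override protected def computeStats: Statistics = Statistics(sizeInBytes = size)
}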

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -64,11 +63,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
override def validConstraints: Set[Expression] =
child.constraints.union(getAliasedConstraints(projectList))

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
if (conf.cboEnabled) {
ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
ProjectEstimation.estimate(this).getOrElse(super.computeStats)
} else {
super.computeStats(conf)
super.computeStats
}
}
}
@@ -138,11 +137,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
child.constraints.union(predicates.toSet)
}

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
if (conf.cboEnabled) {
FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
FilterEstimation(this).estimate.getOrElse(super.computeStats)
} else {
super.computeStats(conf)
super.computeStats
}
}
}
@@ -191,13 +190,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
}
}

override def computeStats(conf: SQLConf): Statistics = {
val leftSize = left.stats(conf).sizeInBytes
val rightSize = right.stats(conf).sizeInBytes
override def computeStats: Statistics = {
val leftSize = left.stats.sizeInBytes
val rightSize = right.stats.sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
Statistics(
sizeInBytes = sizeInBytes,
hints = left.stats(conf).hints.resetForJoin())
hints = left.stats.hints.resetForJoin())
}
}

@@ -208,8 +207,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le

override protected def validConstraints: Set[Expression] = leftConstraints

override def computeStats(conf: SQLConf): Statistics = {
left.stats(conf).copy()
override def computeStats: Statistics = {
left.stats.copy()
}
}

@@ -247,8 +246,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
children.length > 1 && childrenResolved && allChildrenCompatible
}

override def computeStats(conf: SQLConf): Statistics = {
val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
override def computeStats: Statistics = {
val sizeInBytes = children.map(_.stats.sizeInBytes).sum
Statistics(sizeInBytes = sizeInBytes)
}

@@ -356,20 +355,20 @@ case class Join(
case _ => resolvedExceptNatural
}

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
def simpleEstimation: Statistics = joinType match {
case LeftAnti | LeftSemi =>
// LeftSemi and LeftAnti won't ever be bigger than left
left.stats(conf)
left.stats
case _ =>
// Make sure we don't propagate isBroadcastable in other joins, because
// they could explode the size.
val stats = super.computeStats(conf)
val stats = super.computeStats
stats.copy(hints = stats.hints.resetForJoin())
}

if (conf.cboEnabled) {
JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
JoinEstimation.estimate(this).getOrElse(simpleEstimation)
} else {
simpleEstimation
}
@@ -522,7 +521,7 @@ case class Range(

override def newInstance(): Range = copy(output = output.map(_.newInstance()))

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val sizeInBytes = LongType.defaultSize * numElements
Statistics( sizeInBytes = sizeInBytes )
}
@@ -555,20 +554,20 @@ case class Aggregate(
child.constraints.union(getAliasedConstraints(nonAgg))
}

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
def simpleEstimation: Statistics = {
if (groupingExpressions.isEmpty) {
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
rowCount = Some(1),
hints = child.stats(conf).hints)
hints = child.stats.hints)
} else {
super.computeStats(conf)
super.computeStats
}
}

if (conf.cboEnabled) {
AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
} else {
simpleEstimation
}
@@ -671,8 +670,8 @@ case class Expand(
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))

override def computeStats(conf: SQLConf): Statistics = {
val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
override def computeStats: Statistics = {
val sizeInBytes = super.computeStats.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}

@@ -742,9 +741,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
case _ => None
}
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val childStats = child.stats(conf)
val childStats = child.stats
val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
// Don't propagate column stats, because we don't know the distribution after a limit operation
Statistics(
Expand All @@ -762,9 +761,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
case _ => None
}
}
override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val childStats = child.stats(conf)
val childStats = child.stats
if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
@@ -819,9 +818,9 @@ case class Sample(

override def output: Seq[Attribute] = child.output

override def computeStats(conf: SQLConf): Statistics = {
override def computeStats: Statistics = {
val ratio = upperBound - lowerBound
val childStats = child.stats(conf)
val childStats = child.stats
var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
if (sizeInBytes == 0) {
sizeInBytes = 1
@@ -885,7 +884,7 @@ case class RepartitionByExpression(
case object OneRowRelation extends LeafNode {
override def maxRows: Option[Long] = Some(1)
override def output: Seq[Attribute] = Nil
override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
override def computeStats: Statistics = Statistics(sizeInBytes = 1)
}

/** A logical plan for `dropDuplicates`. */
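
For a concrete feel of the simple estimations above, here is a hedged standalone sketch (not the Spark classes) of the limit cap on row count and of the unary node's size rescaling:

// Limit: the row count is capped by the limit when the child reports one.
def limitRowCount(limit: Int, childRowCount: Option[BigInt]): BigInt =
  childRowCount.map(_.min(BigInt(limit))).getOrElse(BigInt(limit))

// Unary node default: rescale the child's size by the ratio of output to child
// row widths (each width gets 8 bytes of per-row overhead; the result stays >= 1).
def rescaledSize(childSize: BigInt, childRowWidth: Int, outputRowWidth: Int): BigInt = {
  val size = childSize * (outputRowWidth + 8) / (childRowWidth + 8)
  if (size == 0) BigInt(1) else size
}

// Example: a 1000-byte child with 32-byte rows projected down to 8-byte rows is
// estimated at 1000 * 16 / 40 = 400 bytes.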

@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf

/**
* A general hint for the child that is not yet resolved. This node is generated by the parser and
@@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())

override lazy val canonicalized: LogicalPlan = child.canonicalized

override def computeStats(conf: SQLConf): Statistics = {
val stats = child.stats(conf)
override def computeStats: Statistics = {
Member: ditto

val stats = child.stats
stats.copy(hints = hints)
}
}

@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
import org.apache.spark.sql.internal.SQLConf


object AggregateEstimation {
@@ -29,13 +28,13 @@
* Estimate the number of output rows based on column stats of group-by columns, and propagate
* column stats for aggregate expressions.
*/
def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
val childStats = agg.child.stats(conf)
def estimate(agg: Aggregate): Option[Statistics] = {
val childStats = agg.child.stats
// Check if we have column stats for all group-by columns.
val colStatsExist = agg.groupingExpressions.forall { e =>
e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
}
if (rowCountsExist(conf, agg.child) && colStatsExist) {
if (rowCountsExist(agg.child) && colStatsExist) {
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes
// the data contains all combinations of distinct values of group-by columns.
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
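
A small standalone sketch of the upper bound described in the comment above (hypothetical helper; the real estimator reads the distinct counts from the child's attribute statistics and also caps the result at the child's row count):

// Multiply the distinct counts of the group-by columns, assuming the data can
// contain every combination, then cap the estimate at the number of input rows.
def groupByUpperBound(distinctCounts: Seq[BigInt], childRowCount: BigInt): BigInt =
  distinctCounts.foldLeft(BigInt(1))(_ * _).min(childRowCount)

// Example: grouping by two columns with 10 and 20 distinct values over 150 input
// rows gives min(10 * 20, 150) = 150 estimated output rows.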