@@ -335,6 +335,11 @@ class Analyzer(
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}.transform {
// We are duplicating aggregates that are now computing a different value for each
// pivot value.
// TODO: Don't construct the physical container until after analysis.
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
@@ -1152,11 +1157,11 @@ class Analyzer(

// Extract Windowed AggregateExpression
case we @ WindowExpression(
AggregateExpression(function, mode, isDistinct),
ae @ AggregateExpression(function, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val newAgg = AggregateExpression(newFunction, mode, isDistinct)
val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)

@@ -76,7 +76,7 @@ trait CheckAnalysis {
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")

case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")

case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
package object errors {

class TreeNodeException[TreeType <: TreeNode[_]](
tree: TreeType, msg: String, cause: Throwable)
@transient val tree: TreeType,
msg: String,
cause: Throwable)
extends Exception(msg, cause) {

val treeString = tree.toString

// Yes, this is the same as a default parameter, but... those don't seem to work with SBT
// external project dependencies for some reason.
def this(tree: TreeType, msg: String) = this(tree, msg, null)

override def getMessage: String = {
val treeString = tree.toString
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -66,17 +67,51 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}

object AggregateExpression {
def apply(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
isDistinct: Boolean): AggregateExpression = {
AggregateExpression(
aggregateFunction,
mode,
isDistinct,
NamedExpression.newExprId)
}
}

/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating whether the DISTINCT keyword was specified for this function.
*/
private[sql] case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
isDistinct: Boolean)
isDistinct: Boolean,
resultId: ExprId)
extends Expression
with Unevaluable {

lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
AttributeReference(
aggregateFunction.toString,
aggregateFunction.dataType,
aggregateFunction.nullable)(exprId = resultId)
} else {
// This is a bit of a hack. Really we should not be constructing this container and reasoning
// about datatypes / aggregation mode until after we have finished analysis and made it to
// planning.
UnresolvedAttribute(aggregateFunction.toString)
}

// The aggregate computes the same value regardless of its resultId, so canonicalization
// normalizes the id to ExprId(0).
override lazy val canonicalized: Expression =
AggregateExpression(
aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
mode,
isDistinct,
ExprId(0))

override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
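
A minimal sketch (not part of this diff, using the API added above) of what the new resultId buys: two syntactically identical aggregates get distinct result attributes that downstream operators can reference independently, while the canonicalized override normalizes the id so the two still compare as the same computation.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types.IntegerType

val x = AttributeReference("x", IntegerType)()

// Two syntactically identical aggregates receive distinct result ids ...
val a1 = AggregateExpression(Max(x), Complete, isDistinct = false)
val a2 = AggregateExpression(Max(x), Complete, isDistinct = false)
assert(a1.resultId != a2.resultId)
assert(a1.resultAttribute != a2.resultAttribute)

// ... but canonicalization zeroes the id, so both reduce to the same canonical form.
assert(a1.canonicalized == a2.canonicalized)
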
@@ -329,7 +329,7 @@ case class PrettyAttribute(
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifier: Option[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
override def nullable: Boolean = throw new UnsupportedOperationException
override def nullable: Boolean = true
}

object VirtualColumn {
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should only be triggered when the isDistinct field is false.
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
ae.copy(aggregateFunction = Count(Literal(1)))

// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
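
A minimal sketch (not part of this diff) of why the rewrites above use ae.copy(aggregateFunction = ...) rather than constructing a new AggregateExpression: copy preserves the existing resultId, so anything already pointing at the aggregate's resultAttribute stays valid, whereas the three-argument companion apply mints a fresh id. The column name below is hypothetical.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types.LongType

val col = AttributeReference("amount", LongType)()
val sum = AggregateExpression(Sum(col), Complete, isDistinct = false)

// copy keeps the original resultId ...
val rewritten = sum.copy(aggregateFunction = Sum(Add(col, Literal(1L))))
assert(rewritten.resultId == sum.resultId)

// ... while the companion apply assigns a new one.
val fresh = AggregateExpression(Sum(col), Complete, isDistinct = false)
assert(fresh.resultId != sum.resultId)
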
@@ -22,6 +22,7 @@ import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
case _ => None
}
}

/**
* An extractor used when planning the physical execution of an aggregation. Compared with a logical
* aggregation, the following transformations are performed:
* - Unnamed grouping expressions are named so that they can be referred to across phases of
* aggregation
* - Aggregations that appear multiple times are deduplicated.
* - The computation of the aggregations themselves is separated from the final result. For example,
* the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
* computation that computes `count.resultAttribute + 1`.
*/
object PhysicalAggregation {
// groupingExpressions, aggregateExpressions, resultExpressions, child
type ReturnType =
(Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of the distinct aggregate expressions and build a function which can
// be used to re-write expressions so that they reference the single copy of the
// aggregate function which actually gets computed.
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression => agg
}
}.distinct

val namedGroupingExpressions = groupingExpressions.map {
case ne: NamedExpression => ne -> ne
// If the expression is not a NamedExpression, we add an alias so that, when the
// operator's result is generated, the aggregate operator can directly get the Seq
// of attributes representing the grouping expressions.
case other =>
val withAlias = Alias(other, other.toString)()
other -> withAlias
}
val groupExpressionMap = namedGroupingExpressions.toMap

// The original `resultExpressions` are a set of expressions which may reference
// aggregate expressions, grouping column values, and constants. When the aggregate operator
// emits output rows, we will use `resultExpressions` to generate an output projection
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transformDown {
case ae: AggregateExpression =>
// Replace each aggregate expression with the attribute that will hold its final
// value in the aggregation buffer (i.e., its resultAttribute):
ae.resultAttribute
case expression =>
// Grouping key expressions are replaced with their corresponding attributes from
// `groupExpressionMap`. We do not rely on direct equality here since attributes may
// differ cosmetically; instead, we use semanticEquals.
groupExpressionMap.collectFirst {
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}

Some((
namedGroupingExpressions.map(_._2),
aggregateExpressions,
rewrittenResultExpressions,
child))

case _ => None
}
}
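
A minimal sketch (not part of this diff) of how the extractor decomposes a logical Aggregate; the relation and column names are hypothetical, and the expected results in the comments assume the behaviour described in the scaladoc above.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}
import org.apache.spark.sql.types.IntegerType

val key   = AttributeReference("key", IntegerType)()
val value = AttributeReference("value", IntegerType)()
val cnt   = AggregateExpression(Count(value :: Nil), Complete, isDistinct = false)

// SELECT key, count(value) + 1 AS c FROM data GROUP BY key
val query = Aggregate(
  groupingExpressions = key :: Nil,
  aggregateExpressions = Seq(key, Alias(Add(cnt, Literal(1L)), "c")()),
  child = LocalRelation(key, value))

query match {
  case PhysicalAggregation(grouping, aggregates, results, child) =>
    // grouping   == Seq(key)        (already named, so no alias is added)
    // aggregates == Seq(cnt)        (the deduplicated AggregateExpressions)
    // results    == Seq(key, Alias(cnt.resultAttribute + 1, "c"))
}
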
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._

@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
Alias(a.child, a.name)(exprId = ExprId(0))
case ae: AggregateExpression =>
ae.copy(resultId = ExprId(0))
}
}

@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}

/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
*/
class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

// TODO: Move the planner and optimizer here from SessionState.
protected def planner = sqlContext.sessionState.planner

def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
case e: AnalysisException =>
val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
planner.plan(ReturnAnswer(optimizedPlan)).next()
}

// executedPlan should not be used to initialize any SparkPlan. It should only
// be used for execution.
lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)

/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()

/**
* Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
* row format conversions as needed.
*/
protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
}

/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
PlanSubqueries(sqlContext),
EnsureRequirements(sqlContext.conf),
CollapseCodegenStages(sqlContext.conf),
ReuseExchange(sqlContext.conf))
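
A minimal sketch (not part of this diff) of what making prepareForExecution and preparations overridable members enables: a subclass can extend the rule sequence without re-implementing the execution pipeline. InstrumentedQueryExecution and NoopRule are hypothetical names; the package declaration keeps the private[sql] types visible.

package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A placeholder rule that leaves the physical plan unchanged.
object NoopRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

class InstrumentedQueryExecution(ctx: SQLContext, plan: LogicalPlan)
  extends QueryExecution(ctx, plan) {
  // Run the standard preparation rules, then our extra one.
  override protected def preparations: Seq[Rule[SparkPlan]] =
    super.preparations :+ NoopRule
}
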

protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }

@@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan {
override def producedAttributes: AttributeSet = outputSet
}

object UnaryNode {
def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
case _ => None
}
}

private[sql] trait UnaryNode extends SparkPlan {
def child: SparkPlan

@@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
val experimentalMethods: ExperimentalMethods)
val extraStrategies: Seq[Strategy])
extends SparkStrategies {

def numPartitions: Int = conf.numShufflePartitions

def strategies: Seq[Strategy] =
experimentalMethods.extraStrategies ++ (
extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::