@@ -546,6 +546,9 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
     AttributeReference(s"MS[$i]", LongType)()
   }
 
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
   /** Fill all words with zeros. */
   override def initialize(buffer: MutableRow): Unit = {
     var word = 0
@@ -231,9 +231,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
    * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
    */
   def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
-
-  final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
 }
 
 /**
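Note on the two hunks above: `inputAggBufferAttributes` stops being a `final lazy val` on `ImperativeAggregate` and is instead declared by each concrete function (here `HyperLogLogPlusPlus`), still as `aggBufferAttributes.map(_.newInstance())`; the `newInstance()` copies carry fresh expression IDs, so the attributes of the serialized buffer arriving from a partial aggregate can be told apart from the attributes of the buffer being updated. The `merge` doc comment also spells out the offset convention: field `i` of the mutable buffer lives at `mutableAggBufferOffset + i`, and field `i` of the input buffer at `inputAggBufferOffset + i`. Below is a minimal, self-contained sketch of that convention, using plain `Array[Long]` rows instead of Spark's `MutableRow`/`InternalRow`; the `ToyLongSum` class and the array-backed rows are illustrative only, not part of the PR.

    // Toy model of the offset convention described in merge()'s doc comment.
    class ToyLongSum(mutableAggBufferOffset: Int, inputAggBufferOffset: Int) {
      // This function keeps a single Long field in its aggregation buffer.
      def initialize(mutableAggBuffer: Array[Long]): Unit =
        mutableAggBuffer(mutableAggBufferOffset) = 0L

      def update(mutableAggBuffer: Array[Long], input: Long): Unit =
        mutableAggBuffer(mutableAggBufferOffset) += input

      // Read the partial result at inputAggBufferOffset and fold it into the
      // mutable buffer at mutableAggBufferOffset; the two offsets can differ
      // because the two rows have different layouts.
      def merge(mutableAggBuffer: Array[Long], inputAggBuffer: Array[Long]): Unit =
        mutableAggBuffer(mutableAggBufferOffset) += inputAggBuffer(inputAggBufferOffset)
    }

    object ToyLongSumDemo extends App {
      // Mutable buffer row: [sum] (offset 0); input row from a partial
      // aggregate: [groupingKey, sum] (sum at offset 1).
      val f = new ToyLongSum(mutableAggBufferOffset = 0, inputAggBufferOffset = 1)
      val mutableAggBuffer = Array(0L)
      f.initialize(mutableAggBuffer)
      f.update(mutableAggBuffer, 3L)
      f.merge(mutableAggBuffer, Array(42L, 7L))   // grouping key 42, partial sum 7
      assert(mutableAggBuffer(0) == 10L)
    }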
@@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         converted match {
           case None => Nil // Cannot convert to new aggregation code path.
           case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
-            // Extracts all distinct aggregate expressions from the resultExpressions.
+            // A single aggregate expression might appear multiple times in resultExpressions.
+            // In order to avoid evaluating an individual aggregate function multiple times, we'll
+            // build a set of the distinct aggregate expressions and build a function which can
+            // be used to re-write expressions so that they reference the single copy of the
+            // aggregate function which actually gets computed.
             val aggregateExpressions = resultExpressions.flatMap { expr =>
               expr.collect {
                 case agg: AggregateExpression2 => agg
               }
-            }.toSet.toSeq
+            }.distinct
             // For those distinct aggregate expressions, we create a map from the
             // aggregate function to the corresponding attribute of the function.
-            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+            val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
               val aggregateFunction = agg.aggregateFunction
-              val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-              (aggregateFunction, agg.isDistinct) ->
-                (aggregateFunction -> attribtue)
+              val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+              (aggregateFunction, agg.isDistinct) -> attribute
             }.toMap
 
             val (functionsWithDistinct, functionsWithoutDistinct) =
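The expanded comment above states the intent of this first step: an aggregate expression that occurs several times in `resultExpressions` should only be planned and evaluated once, so the planner deduplicates the collected expressions and then records, for every `(aggregateFunction, isDistinct)` pair, the single attribute under which that function's result will be exposed. Switching from `.toSet.toSeq` to `.distinct` keeps the deduplication while also preserving the original expression order. A rough, Spark-free illustration of the idea for `SELECT sum(v), sum(v) + 1 FROM t`, where `sum(v)` appears twice but should be computed once; the `Agg` case class below is a stand-in invented for this sketch, not Spark's `AggregateExpression2`.

    // Stand-in for AggregateExpression2, for illustration only.
    case class Agg(function: String, isDistinct: Boolean)

    object DedupDemo extends App {
      // The two result expressions of "SELECT sum(v), sum(v) + 1 FROM t"
      // both reference the same aggregate expression.
      val collected: Seq[Agg] = Seq(
        Agg("sum(v)", isDistinct = false),   // from the column sum(v)
        Agg("sum(v)", isDistinct = false))   // from the column sum(v) + 1

      // Mirrors `resultExpressions.flatMap(_.collect { ... }).distinct`.
      val aggregateExpressions = collected.distinct

      // Mirrors aggregateFunctionToAttribute: one output attribute per
      // distinct (function, isDistinct) pair.
      val aggregateFunctionToAttribute: Map[(String, Boolean), String] =
        aggregateExpressions.map(a => (a.function, a.isDistinct) -> s"${a.function}#1").toMap

      assert(aggregateExpressions.size == 1)
      assert(aggregateFunctionToAttribute(("sum(v)", false)) == "sum(v)#1")
    }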
@@ -220,33 +223,67 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                 "code path.")
             }
 
+            val namedGroupingExpressions = groupingExpressions.map {
+              case ne: NamedExpression => ne -> ne
+              // If the expression is not a NamedExpression, we add an alias.
+              // So, when we generate the result of the operator, the Aggregate Operator
+              // can directly get the Seq of attributes representing the grouping expressions.
+              case other =>
+                val withAlias = Alias(other, other.toString)()
+                other -> withAlias
+            }
+            val groupExpressionMap = namedGroupingExpressions.toMap
+
+            // The original `resultExpressions` are a set of expressions which may reference
+            // aggregate expressions, grouping column values, and constants. When aggregate operator
+            // emits output rows, we will use `resultExpressions` to generate an output projection
+            // which takes the grouping columns and final aggregate result buffer as input.
+            // Thus, we must re-write the result expressions so that their attributes match up with
+            // the attributes of the final result projection's input row:
+            val rewrittenResultExpressions = resultExpressions.map { expr =>
+              expr.transformDown {
+                case AggregateExpression2(aggregateFunction, _, isDistinct) =>
+                  // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+                  // so replace each aggregate expression by its corresponding attribute in the set:
+                  aggregateFunctionToAttribute(aggregateFunction, isDistinct)
+                case expression =>
+                  // Since we're using `namedGroupingAttributes` to extract the grouping key
+                  // columns, we need to replace grouping key expressions with their corresponding
+                  // attributes. We do not rely on the equality check here since attributes may
+                  // differ cosmetically. Instead, we use semanticEquals.
+                  groupExpressionMap.collectFirst {
+                    case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+                  }.getOrElse(expression)
+              }.asInstanceOf[NamedExpression]
+            }
+
             val aggregateOperator =
               if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
                 if (functionsWithDistinct.nonEmpty) {
                   sys.error("Distinct columns cannot exist in Aggregate operator containing " +
                     "aggregate functions which don't support partial aggregation.")
                 } else {
                   aggregate.Utils.planAggregateWithoutPartial(
-                    groupingExpressions,
+                    namedGroupingExpressions.map(_._2),
                     aggregateExpressions,
-                    aggregateFunctionMap,
-                    resultExpressions,
+                    aggregateFunctionToAttribute,
+                    rewrittenResultExpressions,
                     planLater(child))
                 }
               } else if (functionsWithDistinct.isEmpty) {
                 aggregate.Utils.planAggregateWithoutDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   aggregateExpressions,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               } else {
                 aggregate.Utils.planAggregateWithOneDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   functionsWithDistinct,
                   functionsWithoutDistinct,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               }
 
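The rest of the new planning code rewrites `resultExpressions` so that, when the chosen aggregate operator projects its output, every expression refers only to columns that the operator actually produces: each `AggregateExpression2` is replaced by the attribute recorded in `aggregateFunctionToAttribute`, and each grouping expression is replaced by the attribute of its alias, matched via `semanticEquals` because two occurrences of the same expression can differ cosmetically (for example in expression IDs or qualifiers) after analysis. The rewritten expressions and the attribute map are then handed to whichever of the three `aggregate.Utils.planAggregate*` paths applies. A simplified, Spark-free sketch of the substitution for `SELECT k + 1, sum(v) + 1 FROM t GROUP BY k + 1` follows; the `Expr` mini-AST and the `#n` attribute names are invented for this sketch, and plain `==` stands in for `semanticEquals`.

    // Minimal stand-ins for Catalyst expressions, invented for this sketch.
    sealed trait Expr
    case class Column(name: String)    extends Expr   // a grouping expression input
    case class Attr(name: String)      extends Expr   // an attribute produced by the aggregate
    case class AggExpr(fn: String)     extends Expr   // e.g. sum(v)
    case class Literal(value: Int)     extends Expr
    case class Add(l: Expr, r: Expr)   extends Expr

    object RewriteDemo extends App {
      // SELECT k + 1, sum(v) + 1 FROM t GROUP BY k + 1
      val groupingExpr      = Add(Column("k"), Literal(1))
      val resultExpressions = Seq(groupingExpr, Add(AggExpr("sum(v)"), Literal(1)))

      // What the planner builds above: grouping expression -> alias attribute,
      // and aggregate function -> result attribute.
      val groupExpressionMap           = Map[Expr, Attr](groupingExpr -> Attr("(k + 1)#1"))
      val aggregateFunctionToAttribute = Map("sum(v)" -> Attr("sum(v)#2"))

      // Top-down rewrite, the analogue of `expr.transformDown { ... }`.
      def rewrite(e: Expr): Expr = e match {
        case AggExpr(fn)                         => aggregateFunctionToAttribute(fn)
        case _ if groupExpressionMap.contains(e) => groupExpressionMap(e)  // "semanticEquals" stand-in
        case Add(l, r)                           => Add(rewrite(l), rewrite(r))
        case other                               => other
      }

      val rewritten = resultExpressions.map(rewrite)
      // Both output columns now refer only to attributes the aggregate emits.
      assert(rewritten == Seq(Attr("(k + 1)#1"), Add(Attr("sum(v)#2"), Literal(1))))
    }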
@@ -321,7 +321,7 @@ abstract class AggregationIterator(
   // Initializing the function used to generate the output row.
   protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
     val rowToBeEvaluated = new JoinedRow
-    val safeOutputRow = new GenericMutableRow(resultExpressions.length)
+    val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
     val mutableOutput = if (outputsUnsafeRows) {
       UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
     } else {
@@ -359,15 +359,16 @@
         val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
         val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
         // TODO: Use unsafe row.
-        val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+        val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+        expressionAggEvalProjection.target(aggregateResult)
         val resultProjection =
           newMutableProjection(
             resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
         resultProjection.target(mutableOutput)
 
         (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
           // Generate results for all expression-based aggregate functions.
-          expressionAggEvalProjection.target(aggregateResult)(currentBuffer)
+          expressionAggEvalProjection(currentBuffer)
           // Generate results for all imperative aggregate functions.
           var i = 0
           while (i < allImperativeAggregateFunctions.length) {
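Two small changes in this iterator: the reusable output and aggregate-result rows are now `SpecificMutableRow`s built from the actual field data types instead of `GenericMutableRow`s sized only by field count, and `expressionAggEvalProjection.target(aggregateResult)` is called once when the iterator is set up rather than on every row. The likely motivation, not stated in the diff, is that a `SpecificMutableRow` keeps a typed holder per field, so primitive values are written without boxing, and re-targeting a projection to the same row on every call is redundant. A toy illustration of the boxing difference, with invented row classes in place of Spark's:

    // Toy rows; Spark's GenericMutableRow / SpecificMutableRow are more general.
    final class ToyGenericRow(numFields: Int) {
      private val values = new Array[Any](numFields)
      def setLong(i: Int, v: Long): Unit = values(i) = v        // boxes v into java.lang.Long
      def getLong(i: Int): Long = values(i).asInstanceOf[Long]  // unboxes on every read
    }

    final class ToySpecificLongRow(numFields: Int) {
      private val values = new Array[Long](numFields)           // primitive storage, no boxing
      def setLong(i: Int, v: Long): Unit = values(i) = v
      def getLong(i: Int): Long = values(i)
    }

    object RowDemo extends App {
      val generic  = new ToyGenericRow(1)
      val specific = new ToySpecificLongRow(1)
      generic.setLong(0, 7L)
      specific.setLong(0, 7L)
      assert(generic.getLong(0) == specific.getLong(0))
    }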
@@ -31,7 +31,10 @@ case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
@@ -77,7 +80,10 @@ case class TungstenAggregate(
         new TungstenAggregationIterator(
           groupingExpressions,
           nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
           completeAggregateExpressions,
+          completeAggregateAttributes,
+          initialInputBufferOffset,
           resultExpressions,
           newMutableProjection,
           child.output,
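`TungstenAggregate` and `TungstenAggregationIterator` now take the pre-computed attribute lists for the non-complete and complete aggregate expressions, plus an `initialInputBufferOffset`, instead of deriving them internally; the planner changes above are what supply them. The offset presumably describes rows arriving from an upstream partial aggregate, which are laid out as the grouping-key columns followed by the aggregation-buffer columns, so the merge side starts reading buffer fields right after the grouping key. A small sketch of that assumed layout (the column names are invented, and the layout itself is an inference from the planner code, not something this diff states):

    object OffsetDemo extends App {
      // Assumed output layout of a partial aggregate:
      // [grouping key columns | aggregation buffer columns]
      val groupingColumns  = Seq("k1", "k2")
      val bufferColumns    = Seq("count", "sum")
      val partialOutputRow = groupingColumns ++ bufferColumns

      // The merge side would then read buffer fields starting here.
      val initialInputBufferOffset = groupingColumns.length

      assert(partialOutputRow(initialInputBufferOffset) == "count")
    }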