-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-9240] [SQL] Hybrid aggregate operator using unsafe row #7813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
3915bac
299008c
af32210
f60cc83
d2c45a0
3171f44
f52ee53
bd9282b
33b7022
533d5b2
7fcbd87
b1ea5cf
964f88b
21fd15f
0f1b06f
c9cf3b6
ba6afbc
74d93c5
e317e2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -161,6 +161,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |
| allAggregates(rewrittenAggregateExpressions)) && | ||
| codegenEnabled && | ||
| !canBeConvertedToNewAggregation(plan) => | ||
| logInfo(s"Using ${classOf[execution.GeneratedAggregate].getCanonicalName} as " + | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Either remove this or change it to logDebug (preferably remove it).
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The same applies to all of the following log statements. |
||
| s"the physical Aggregate Operator.") | ||
| execution.GeneratedAggregate( | ||
| partial = false, | ||
| namedGroupingAttributes, | ||
|
|
@@ -180,6 +182,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |
| groupingExpressions, | ||
| partialComputation, | ||
| child) if !canBeConvertedToNewAggregation(plan) => | ||
| logInfo(s"Using ${classOf[execution.Aggregate].getCanonicalName} as " + | ||
| s"the physical Aggregate Operator.") | ||
| execution.Aggregate( | ||
| partial = false, | ||
| namedGroupingAttributes, | ||
|
|
@@ -227,6 +231,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |
| converted match { | ||
| case None => Nil // Cannot convert to new aggregation code path. | ||
| case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => | ||
| logInfo(s"Using ${classOf[aggregate.Aggregate].getCanonicalName} as " + | ||
| s"the physical Aggregate Operator.") | ||
| // Extracts all distinct aggregate expressions from the resultExpressions. | ||
| val aggregateExpressions = resultExpressions.flatMap { expr => | ||
| expr.collect { | ||
|
|
@@ -386,6 +392,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |
| Nil | ||
| } else { | ||
| Utils.checkInvalidAggregateFunction2(a) | ||
| logInfo(s"Using ${classOf[execution.Aggregate].getCanonicalName} as " + | ||
| s"the physical Aggregate Operator.") | ||
| execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.execution.aggregate | ||
|
|
||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.catalyst.errors._ | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
| import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} | ||
| import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} | ||
| import org.apache.spark.sql.types.StructType | ||
|
|
||
| /** | ||
| * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types | ||
| * of the grouping expressions and aggregate functions, it determines whether to use | ||
| * sort-based aggregation or hybrid aggregation (hash-based with sort-based as the | ||
| * fallback) to process input rows. | ||
| */ | ||
| case class Aggregate( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it would be a lot clearer if you just separated this into two operators.
This operator is clearly doing the job of two different operators, just with a lot of if branches. It would be clearer if those branches went into the planner. |
||
| requiredChildDistributionExpressions: Option[Seq[Expression]], | ||
| groupingExpressions: Seq[NamedExpression], | ||
| nonCompleteAggregateExpressions: Seq[AggregateExpression2], | ||
| nonCompleteAggregateAttributes: Seq[Attribute], | ||
| completeAggregateExpressions: Seq[AggregateExpression2], | ||
| completeAggregateAttributes: Seq[Attribute], | ||
| initialInputBufferOffset: Int, | ||
| resultExpressions: Seq[NamedExpression], | ||
| child: SparkPlan) | ||
| extends UnaryNode { | ||
|
|
||
| // All aggregate expressions of this operator: the non-complete ones followed by the | ||
| // complete ones. | ||
| private[this] val allAggregateExpressions = | ||
| nonCompleteAggregateExpressions ++ completeAggregateExpressions | ||
|
|
||
| // True when at least one aggregate function is not an AlgebraicAggregate. | ||
| private[this] val hasNonAlgebricAggregateFunctions = | ||
|   allAggregateExpressions.exists(!_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) | ||
|
|
||
| // The hybrid iterator can be used only when (1) unsafe mode is enabled, (2) the schemas | ||
| // of both the grouping key and the aggregation buffer are supported, and (3) every | ||
| // aggregate function is algebraic. | ||
| private[this] val supportsHybridIterator: Boolean = { | ||
|   // Schema of the aggregation buffer: the buffer attributes of every aggregate function. | ||
|   val bufferSchema: StructType = StructType.fromAttributes( | ||
|     allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) | ||
|   // Schema of the grouping key: one attribute per grouping expression. | ||
|   val keySchema: StructType = | ||
|     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) | ||
|   val unsafeCompatibleSchemas: Boolean = | ||
|     UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(bufferSchema) && | ||
|       UnsafeProjection.canSupport(keySchema) | ||
|   // TODO: Use the hybrid iterator for non-algebraic aggregate functions. | ||
|   sqlContext.conf.unsafeEnabled && unsafeCompatibleSchemas && !hasNonAlgebricAggregateFunctions | ||
| } | ||
|
|
||
| // Value of the `useHybridAggregate` setting, read once from the SQL context's conf. | ||
| private[this] val hybridAggregateEnabled = sqlContext.conf.useHybridAggregate | ||
|
|
||
| // Sorted input is needed whenever there are grouping expressions and the hybrid | ||
| // iterator will not be used, either because it is unsupported or because it is disabled. | ||
| private[this] val requiresSortedInput: Boolean = | ||
|   groupingExpressions.nonEmpty && !(supportsHybridIterator && hybridAggregateEnabled) | ||
|
|
||
| // Reports that unsafe rows can be processed only when every aggregate function is an | ||
| // AlgebraicAggregate. | ||
| override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions | ||
|
|
||
| // Generating unsafe rows is currently disabled, so this always returns false. | ||
| // The intended rule is to generate unsafe rows when the data types of all result | ||
| // expressions are fixed length. That requirement is used instead of checking the result | ||
| // of UnsafeProjection.canSupport because a mutable projection generates the result. | ||
| // TODO: Support generating UnsafeRows by re-enabling the check below and fixing any | ||
| // issue that shows up: | ||
| //   resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) | ||
| override def outputsUnsafeRows: Boolean = false | ||
|
|
||
| // The output attributes are the attributes of the result expressions, in order. | ||
| override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) | ||
|
|
||
| // Maps requiredChildDistributionExpressions to the distribution required of the child: | ||
| // AllTuples for an empty expression list, ClusteredDistribution over the expressions for | ||
| // a non-empty list, and UnspecifiedDistribution when no expressions are given. | ||
| override def requiredChildDistribution: List[Distribution] = { | ||
|   val distribution: Distribution = requiredChildDistributionExpressions match { | ||
|     case None => UnspecifiedDistribution | ||
|     case Some(exprs) if exprs.isEmpty => AllTuples | ||
|     case Some(exprs) => ClusteredDistribution(exprs) | ||
|   } | ||
|   distribution :: Nil | ||
| } | ||
|
|
||
| // Requires the child to be sorted ascending on the grouping expressions when sorted | ||
| // input is needed; otherwise places no ordering requirement on any child. | ||
| override def requiredChildOrdering: Seq[Seq[SortOrder]] = { | ||
|   if (!requiresSortedInput) { | ||
|     Seq.fill(children.size)(Nil) | ||
|   } else { | ||
|     // TODO: We should not sort the input rows if they are just in reversed order. | ||
|     val ascendingByGroupingKey = groupingExpressions.map(SortOrder(_, Ascending)) | ||
|     ascendingByGroupingKey :: Nil | ||
|   } | ||
| } | ||
|
|
||
| // When sorted input is required, the output is reported as sorted ascending on the | ||
| // grouping expressions only. The child's outputOrdering may start with the required | ||
| // ordering and continue further (e.g. [a] is required and the child provides [a, b]), | ||
| // but only the ordering on the grouping expressions can be guaranteed for output rows. | ||
| // Without the sorted-input requirement no output ordering is reported. | ||
| override def outputOrdering: Seq[SortOrder] = | ||
|   if (requiresSortedInput) groupingExpressions.map(SortOrder(_, Ascending)) else Nil | ||
|
|
||
| // Runs the aggregation on each partition of the child's output, choosing per partition | ||
| // between the unsafe hybrid iterator and the sort-based iterator. | ||
| protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { | ||
|   child.execute().mapPartitions { inputRows => | ||
|     // The constructor of an aggregation iterator reads at least the first row, so whether | ||
|     // the partition has any input must be captured before an iterator is created. | ||
|     val partitionHasRows = inputRows.hasNext | ||
|     val isGrouped = groupingExpressions.nonEmpty | ||
|     val useHybridIterator = | ||
|       partitionHasRows && supportsHybridIterator && isGrouped && hybridAggregateEnabled | ||
|     if (useHybridIterator) { | ||
|       UnsafeHybridAggregationIterator.createFromInputIterator( | ||
|         groupingExpressions, | ||
|         nonCompleteAggregateExpressions, | ||
|         nonCompleteAggregateAttributes, | ||
|         completeAggregateExpressions, | ||
|         completeAggregateAttributes, | ||
|         initialInputBufferOffset, | ||
|         resultExpressions, | ||
|         newMutableProjection _, | ||
|         child.output, | ||
|         inputRows, | ||
|         outputsUnsafeRows) | ||
|     } else if (!partitionHasRows && isGrouped) { | ||
|       // A grouped aggregate over an empty partition produces no rows. | ||
|       Iterator[InternalRow]() | ||
|     } else { | ||
|       val sortBasedIter = SortBasedAggregationIterator.createFromInputIterator( | ||
|         groupingExpressions, | ||
|         nonCompleteAggregateExpressions, | ||
|         nonCompleteAggregateAttributes, | ||
|         completeAggregateExpressions, | ||
|         completeAggregateAttributes, | ||
|         initialInputBufferOffset, | ||
|         resultExpressions, | ||
|         newMutableProjection _, | ||
|         newProjection _, | ||
|         child.output, | ||
|         inputRows, | ||
|         outputsUnsafeRows) | ||
|       if (partitionHasRows || isGrouped) { | ||
|         sortBasedIter | ||
|       } else { | ||
|         // No input and no grouping expressions: a single row must still be output. | ||
|         Iterator[InternalRow](sortBasedIter.outputForEmptyGroupingKeyWithoutInput()) | ||
|       } | ||
|     } | ||
|   } | ||
| } | ||
|
|
||
| // One-line description of this operator that names the aggregation iterator it is | ||
| // configured to use, followed by the grouping and aggregate expressions. | ||
| override def simpleString: String = { | ||
|   // Use the same conditions as doExecute and requiresSortedInput: the hybrid iterator is | ||
|   // chosen only when it is supported, there are grouping expressions, and hybrid | ||
|   // aggregation is enabled. Without the last check a plan with hybrid aggregation | ||
|   // disabled would be described as hybrid although it runs sort-based. | ||
|   // (At execution time a partition without input still uses the sort-based path.) | ||
|   val usesHybridIterator = | ||
|     supportsHybridIterator && groupingExpressions.nonEmpty && hybridAggregateEnabled | ||
|   val iterator = | ||
|     if (usesHybridIterator) { | ||
|       classOf[UnsafeHybridAggregationIterator].getSimpleName | ||
|     } else { | ||
|       classOf[SortBasedAggregationIterator].getSimpleName | ||
|     } | ||
|   s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we want a config flag here?