[SPARK-14447][SQL] Speed up TungstenAggregate w/ keys using VectorizedHashMap #12345
Changes from 7 commits
|
|
@@ -21,10 +21,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext | |
| import org.apache.spark.sql.types.StructType | ||
|
|
||
| /** | ||
| * This is a helper object to generate an append-only single-key/single value aggregate hash | ||
| * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates | ||
| * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in | ||
| * TungstenAggregate to speed up aggregates w/ key. | ||
| * This is a helper class to generate an append-only aggregate hash map that can act as a 'cache' | ||
| * for extremely fast key-value lookups while evaluating aggregates (and fall back to the | ||
| * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to speed | ||
| * up aggregates w/ key. | ||
| * | ||
| * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the | ||
| * key-value pairs. The index lookups in the array rely on linear probing (with a small number of | ||
|
|
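The doc comment above describes the layout: a power-of-2-sized bucket array for index lookups, linear probing with a bounded number of steps, and an append-only columnar store for the key-value pairs. Below is a minimal, self-contained sketch of that find-or-insert loop, assuming a single long grouping key and a single long COUNT buffer; plain Java arrays stand in for the ColumnarBatch, and this is not the PR's generated code.

```java
import java.util.Arrays;

public class LongAggHashMapSketch {
    private final int numBuckets = 1 << 16;           // power of 2, so we can mask instead of mod
    private final int maxSteps = 2;                    // probe limit before giving up
    private final int[] buckets = new int[numBuckets]; // bucket -> row index, -1 if empty
    private final long[] keys = new long[numBuckets];  // append-only "columnar" key storage
    private final long[] values = new long[numBuckets]; // append-only aggregate buffers

    private int numRows = 0;

    public LongAggHashMapSketch() {
        Arrays.fill(buckets, -1); // -1 marks an empty bucket
    }

    private long hash(long key) { return key; } // placeholder, like the PR's hash TODO

    /** Returns the row index for the key, or -1 if the caller must fall back. */
    public int findOrInsert(long key) {
        int idx = (int) (hash(key) & (numBuckets - 1));
        int step = 0;
        while (step < maxSteps) {
            if (buckets[idx] == -1) {              // empty slot: append a new row
                keys[numRows] = key;
                values[numRows] = 0;
                buckets[idx] = numRows;
                return numRows++;
            } else if (keys[buckets[idx]] == key) { // existing key
                return buckets[idx];
            }
            idx = (idx + 1) & (numBuckets - 1);     // linear probing
            step++;
        }
        return -1; // too many collisions: fall back to the BytesToBytesMap
    }

    public void add(long key) {
        int row = findOrInsert(key);
        if (row >= 0) values[row] += 1; // COUNT(*) update
    }

    public static void main(String[] args) {
        LongAggHashMapSketch map = new LongAggHashMapSketch();
        for (long k : new long[]{1, 2, 1, 1}) map.add(k);
        System.out.println(map.values[map.findOrInsert(1)]); // 3
    }
}
```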
@@ -65,27 +65,43 @@ class ColumnarAggMapCodeGenerator( | |
| .mkString("\n")}; | ||
| """.stripMargin | ||
|
|
||
| val generatedAggBufferSchema: String = | ||
| s""" | ||
| |new org.apache.spark.sql.types.StructType() | ||
| |${bufferSchema.map(key => | ||
| s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") | ||
| .mkString("\n")}; | ||
| """.stripMargin | ||
|
|
||
| s""" | ||
| | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | ||
| | public org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | ||
|
||
| | public org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; | ||
| | private int[] buckets; | ||
| | private int numBuckets; | ||
| | private int maxSteps; | ||
| | private int numRows = 0; | ||
| | private org.apache.spark.sql.types.StructType schema = $generatedSchema | ||
| | private org.apache.spark.sql.types.StructType aggregateBufferSchema = | ||
| | $generatedAggBufferSchema | ||
| | | ||
| | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { | ||
| | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); | ||
| | this.maxSteps = maxSteps; | ||
| | numBuckets = (int) (capacity / loadFactor); | ||
| | public $generatedClassName() { | ||
| | // TODO: These should be generated based on the schema | ||
| | int DEFAULT_CAPACITY = 1 << 16; | ||
| | double DEFAULT_LOAD_FACTOR = 0.25; | ||
| | int DEFAULT_MAX_STEPS = 2; | ||
| | assert (DEFAULT_CAPACITY > 0 && ((DEFAULT_CAPACITY & (DEFAULT_CAPACITY - 1)) == 0)); | ||
| | this.maxSteps = DEFAULT_MAX_STEPS; | ||
| | numBuckets = (int) (DEFAULT_CAPACITY / DEFAULT_LOAD_FACTOR); | ||
| | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, | ||
| | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); | ||
| | org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY); | ||
| | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( | ||
| | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY); | ||
|
Contributor: Let's leave a TODO to fix this. There should be a nicer way to get a projection of a batch instead of this.
||
| | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { | ||
| | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); | ||
| | } | ||
| | buckets = new int[numBuckets]; | ||
| | java.util.Arrays.fill(buckets, -1); | ||
| | } | ||
| | | ||
| | public $generatedClassName() { | ||
| | new $generatedClassName(1 << 16, 0.25, 5); | ||
| | } | ||
| """.stripMargin | ||
| } | ||
|
|
||
|
|
@@ -103,7 +119,7 @@ class ColumnarAggMapCodeGenerator( | |
| s""" | ||
| |// TODO: Improve this hash function | ||
| |private long hash($groupingKeySignature) { | ||
| | return ${groupingKeys.map(_._2).mkString(" ^ ")}; | ||
| | return ${groupingKeys.map(_._2).mkString(" | ")}; | ||
|
Contributor: Any reason not to implement the h = h * 37 + v hash function?
Member (Author): No particular reason; I was planning to do this as part of a separate small PR (along with the benchmarks). Please let me know if you'd prefer it here instead.
||
| |} | ||
| """.stripMargin | ||
| } | ||
|
|
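For reference, a hedged sketch of the multiplicative hash the reviewer suggests, folding each grouping-key column into a running hash; the class and method names here are illustrative, not the PR's generated identifiers.

```java
public class MultiplicativeHashSketch {
    static long hash(long... groupingKeyValues) {
        long h = 0;
        for (long v : groupingKeyValues) {
            h = h * 37 + v; // the reviewer's suggested combine step
        }
        return h;
    }

    public static void main(String[] args) {
        // Unlike XOR (or the '|' in this revision), the multiplier makes the hash
        // sensitive to column order, so (1, 2) and (2, 1) land in different buckets.
        System.out.println(hash(1L, 2L)); // 39
        System.out.println(hash(2L, 1L)); // 75
    }
}
```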
@@ -178,9 +194,10 @@ class ColumnarAggMapCodeGenerator( | |
| s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);") | ||
| .mkString("\n")} | ||
| | buckets[idx] = numRows++; | ||
| | return batch.getRow(buckets[idx]); | ||
| | batch.setNumRows(numRows); | ||
| | return aggregateBufferBatch.getRow(buckets[idx]); | ||
| | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { | ||
| | return batch.getRow(buckets[idx]); | ||
| | return aggregateBufferBatch.getRow(buckets[idx]); | ||
| | } | ||
| | idx = (idx + 1) & (numBuckets - 1); | ||
| | step++; | ||
|
|
||
|
|
@@ -261,6 +261,10 @@ case class TungstenAggregate( | |
| .map(_.asInstanceOf[DeclarativeAggregate]) | ||
| private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) | ||
|
|
||
| // The name for AggregateHashMap | ||
| private var aggregateHashMapTerm: String = _ | ||
| private var isAggregateHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled | ||
|
|
||
| // The name for HashMap | ||
| private var hashMapTerm: String = _ | ||
| private var sorterTerm: String = _ | ||
|
|
@@ -437,17 +441,21 @@ case class TungstenAggregate( | |
| val initAgg = ctx.freshName("initAgg") | ||
| ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") | ||
|
|
||
| // create AggregateHashMap | ||
| val isAggregateHashMapEnabled: Boolean = false | ||
| val isAggregateHashMapSupported: Boolean = | ||
| // We currently only enable aggregate hashmap for long key/value types | ||
| isAggregateHashMapEnabled = isAggregateHashMapEnabled && | ||
| (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) | ||
| val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") | ||
| aggregateHashMapTerm = ctx.freshName("aggregateHashMap") | ||
| val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") | ||
| val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, | ||
| groupingKeySchema, bufferSchema) | ||
| if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { | ||
| // Create a name for iterator from AggregateHashMap | ||
| val iterTermForGeneratedHashMap = ctx.freshName("genMapIter") | ||
| if (isAggregateHashMapEnabled) { | ||
| ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, | ||
| s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") | ||
| ctx.addMutableState( | ||
| "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>", | ||
| iterTermForGeneratedHashMap, "") | ||
| } | ||
|
|
||
| // create hashMap | ||
|
|
@@ -465,11 +473,14 @@ case class TungstenAggregate( | |
| val doAgg = ctx.freshName("doAggregateWithKeys") | ||
| ctx.addNewFunction(doAgg, | ||
| s""" | ||
| ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} | ||
| ${if (isAggregateHashMapEnabled) aggregateHashMapGenerator.generate() else ""} | ||
| private void $doAgg() throws java.io.IOException { | ||
| $hashMapTerm = $thisPlan.createHashMap(); | ||
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | ||
|
|
||
| ${if (isAggregateHashMapEnabled) { | ||
| s"$iterTermForGeneratedHashMap = $aggregateHashMapTerm.batch.rowIterator();"} else ""} | ||
|
|
||
| $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); | ||
| } | ||
| """) | ||
|
|
@@ -484,13 +495,42 @@ case class TungstenAggregate( | |
| // so `copyResult` should be reset to `false`. | ||
| ctx.copyResult = false | ||
|
|
||
| def outputFromGeneratedMap: Option[String] = { | ||
| if (isAggregateHashMapEnabled) { | ||
| val row = ctx.freshName("aggregateHashMapRow") | ||
|
||
| ctx.currentVars = null | ||
| ctx.INPUT_ROW = row | ||
| var schema: StructType = groupingKeySchema | ||
| bufferSchema.foreach(i => schema = schema.add(i)) | ||
| val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex | ||
| .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) | ||
| Option( | ||
| s""" | ||
| | while ($iterTermForGeneratedHashMap.hasNext()) { | ||
| | $numOutput.add(1); | ||
| | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = | ||
| | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) | ||
| | $iterTermForGeneratedHashMap.next(); | ||
| | ${generateRow.code} | ||
| | ${consume(ctx, Seq.empty, {generateRow.value})} | ||
| | | ||
| | if (shouldStop()) return; | ||
| | } | ||
| | | ||
| | $aggregateHashMapTerm.batch.close(); | ||
| """.stripMargin) | ||
| } else None | ||
| } | ||
|
|
||
| s""" | ||
| if (!$initAgg) { | ||
| $initAgg = true; | ||
| $doAgg(); | ||
| } | ||
|
|
||
| // output the result | ||
| ${outputFromGeneratedMap.getOrElse("")} | ||
|
|
||
| while ($iterTerm.next()) { | ||
| $numOutput.add(1); | ||
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | ||
|
|
@@ -513,8 +553,11 @@ case class TungstenAggregate( | |
| ctx.currentVars = input | ||
| val keyCode = GenerateUnsafeProjection.createCode( | ||
|
||
| ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) | ||
| val groupByKeys = ctx.generateExpressions( | ||
| groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) | ||
| val key = keyCode.value | ||
| val buffer = ctx.freshName("aggBuffer") | ||
| val aggregateRow = ctx.freshName("aggregateRow") | ||
|
|
||
| // only have DeclarativeAggregate | ||
| val updateExpr = aggregateExpressions.flatMap { e => | ||
|
|
@@ -533,56 +576,97 @@ case class TungstenAggregate( | |
|
|
||
| val inputAttr = aggregateBufferAttributes ++ child.output | ||
| ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input | ||
|
|
||
| ctx.INPUT_ROW = aggregateRow | ||
| // TODO: support subexpression elimination | ||
|
||
| val aggregateRowEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) | ||
|
||
| val updateAggregateRow = aggregateRowEvals.zipWithIndex.map { case (ev, i) => | ||
| val dt = updateExpr(i).dataType | ||
| ctx.updateColumn(aggregateRow, dt, i, ev, updateExpr(i).nullable) | ||
| } | ||
|
|
||
| ctx.INPUT_ROW = buffer | ||
| // TODO: support subexpression elimination | ||
| val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) | ||
| val updates = evals.zipWithIndex.map { case (ev, i) => | ||
| val aggregateBufferEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) | ||
| val updateAggregateBuffer = aggregateBufferEvals.zipWithIndex.map { case (ev, i) => | ||
| val dt = updateExpr(i).dataType | ||
| ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) | ||
| } | ||
|
|
||
| val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { | ||
| val (checkFallback, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { | ||
| val countTerm = ctx.freshName("fallbackCounter") | ||
| ctx.addMutableState("int", countTerm, s"$countTerm = 0;") | ||
| (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") | ||
| } else { | ||
| ("true", "", "") | ||
| } | ||
|
|
||
| val findOrInsertInGeneratedHashMap: Option[String] = { | ||
| if (isAggregateHashMapEnabled) { | ||
| Option( | ||
| s""" | ||
| | $aggregateRow = | ||
| | $aggregateHashMapTerm.findOrInsert(${groupByKeys.map(_.value).mkString(", ")}); | ||
| """.stripMargin) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
|
|
||
| val findOrInsertInBytesToBytesMap: String = { | ||
| s""" | ||
| | if ($aggregateRow == null) { | ||
| | // generate grouping key | ||
| | ${keyCode.code.trim} | ||
| | ${hashEval.code.trim} | ||
| | if ($checkFallback) { | ||
| | // try to get the buffer from hash map | ||
| | $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); | ||
| | } | ||
| | if ($buffer == null) { | ||
| | if ($sorterTerm == null) { | ||
| | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); | ||
| | } else { | ||
| | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); | ||
| | } | ||
| | $resetCounter | ||
| | // the hash map had be spilled, it should have enough memory now, | ||
| | // try to allocate buffer again. | ||
| | $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); | ||
| | if ($buffer == null) { | ||
| | // failed to allocate the first page | ||
| | throw new OutOfMemoryError("No enough memory for aggregation"); | ||
| | } | ||
| | } | ||
| | } | ||
| """.stripMargin | ||
| } | ||
|
|
||
| // We try to do hash map based in-memory aggregation first. If there is not enough memory (the | ||
| // hash map will return null for new key), we spill the hash map to disk to free memory, then | ||
| // continue to do in-memory aggregation and spilling until all the rows had been processed. | ||
| // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. | ||
| s""" | ||
| // generate grouping key | ||
| ${keyCode.code.trim} | ||
| ${hashEval.code.trim} | ||
| UnsafeRow $buffer = null; | ||
| if ($checkFallback) { | ||
| // try to get the buffer from hash map | ||
| $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); | ||
| } | ||
| if ($buffer == null) { | ||
| if ($sorterTerm == null) { | ||
| $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); | ||
| } else { | ||
| $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); | ||
| } | ||
| $resetCoulter | ||
| // the hash map had be spilled, it should have enough memory now, | ||
| // try to allocate buffer again. | ||
| $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); | ||
| if ($buffer == null) { | ||
| // failed to allocate the first page | ||
| throw new OutOfMemoryError("No enough memory for aggregation"); | ||
| } | ||
| } | ||
| org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $aggregateRow = null; | ||
|
|
||
| ${findOrInsertInGeneratedHashMap.getOrElse("")} | ||
|
|
||
| $findOrInsertInBytesToBytesMap | ||
|
|
||
| $incCounter | ||
|
|
||
| // evaluate aggregate function | ||
| ${evaluateVariables(evals)} | ||
| // update aggregate buffer | ||
| ${updates.mkString("\n").trim} | ||
| if ($aggregateRow != null) { | ||
| // evaluate aggregate function | ||
| ${evaluateVariables(aggregateRowEvals)} | ||
| // update aggregate row | ||
| ${updateAggregateRow.mkString("\n").trim} | ||
| } else { | ||
| // evaluate aggregate function | ||
| ${evaluateVariables(aggregateBufferEvals)} | ||
| // update aggregate buffer | ||
| ${updateAggregateBuffer.mkString("\n").trim} | ||
| } | ||
| """ | ||
| } | ||
|
|
||
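Taken together, the generated per-row code above does a two-level lookup: try the fast generated map first, and only fall back to the spillable BytesToBytesMap on a miss. A simplified, self-contained sketch of that control flow follows, with plain HashMaps standing in for Spark's maps, a tiny fast-map capacity so the demo actually exercises the fallback, and spilling left out.

```java
import java.util.HashMap;
import java.util.Map;

public class TwoLevelAggSketch {
    // Stand-in for the generated append-only map: bounded, long key -> count buffer.
    static final int FAST_MAP_CAPACITY = 2; // artificially small for the demo
    static final Map<Long, long[]> fastMap = new HashMap<>();
    // Stand-in for BytesToBytesMap: unbounded here, spilling is out of scope.
    static final Map<Long, long[]> fallbackMap = new HashMap<>();

    // Mirrors findOrInsert(): returns a buffer from the fast map, or null if it is full.
    static long[] findOrInsertFast(long key) {
        long[] buf = fastMap.get(key);
        if (buf == null && fastMap.size() < FAST_MAP_CAPACITY) {
            buf = new long[]{0};
            fastMap.put(key, buf);
        }
        return buf; // null means "fall back to the regular hash map"
    }

    static long[] findOrInsertFallback(long key) {
        return fallbackMap.computeIfAbsent(key, k -> new long[]{0});
    }

    public static void main(String[] args) {
        long[] keys = {1, 2, 1, 3, 2, 1};
        for (long key : keys) {
            long[] buf = findOrInsertFast(key);   // fast path
            if (buf == null) {
                buf = findOrInsertFallback(key);  // slow path (spillable map in Spark)
            }
            buf[0] += 1;                          // update the aggregate (COUNT here)
        }
        System.out.println(fastMap.get(1L)[0]);     // 3, served by the fast map
        System.out.println(fallbackMap.get(3L)[0]); // 1, served by the fallback map
    }
}
```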
|
|
||
You chose to have this not handle null keys, right? Comment that.