@@ -126,6 +126,7 @@ trait CodegenSupport extends SparkPlan {
// outputVars will be used to generate the code for UnsafeRow, so we should copy them
outputVars.map(_.copy())
}

val rowVar = if (row != null) {
ExprCode("", "false", row)
} else {
@@ -70,12 +70,14 @@ case class TungstenAggregate(
}
}

// This is for testing. We force TungstenAggregationIterator to fall back to sort-based
// aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[Int] = {
// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
case null | "" => None
case fallbackStartsAt => Some(fallbackStartsAt.toInt)
case fallbackStartsAt =>
val splits = fallbackStartsAt.split(",").map(_.trim)
Some((splits.head.toInt, splits.last.toInt))
}
}
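
A standalone sketch of this parsing (hypothetical helper name, mirroring the match above). Note that a legacy single-number value still parses: splits.head and splits.last are the same element of a one-element array, so both thresholds coincide.

def parseFallbackStartsAt(conf: String): Option[(Int, Int)] = conf match {
  case null | "" => None
  case fallbackStartsAt =>
    val splits = fallbackStartsAt.split(",").map(_.trim)
    Some((splits.head.toInt, splits.last.toInt))
}

assert(parseFallbackStartsAt(null) == None)
assert(parseFallbackStartsAt("3") == Some((3, 3)))
assert(parseFallbackStartsAt("3, 5") == Some((3, 5)))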

@@ -261,7 +263,15 @@ case class TungstenAggregate(
.map(_.asInstanceOf[DeclarativeAggregate])
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

// The name for HashMap
// The name for Vectorized HashMap
private var vectorizedHashMapTerm: String = _

// We currently only enable the vectorized hash map for long key/value types and partial aggregates
private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled &&
(groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) &&
modes.forall(mode => mode == Partial || mode == PartialMerge)
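
A quick illustration of the type restriction with hypothetical schemas: a partial aggregate like SELECT k, SUM(v) FROM t GROUP BY k over two long columns is eligible, while any non-long grouping key or buffer field disables the vectorized map.

import org.apache.spark.sql.types._

val longKey = StructType(Seq(StructField("k", LongType)))
val longBuffer = StructType(Seq(StructField("sum", LongType)))
val stringKey = StructType(Seq(StructField("k", StringType)))

assert((longKey ++ longBuffer).forall(_.dataType == LongType))    // eligible
assert(!(stringKey ++ longBuffer).forall(_.dataType == LongType)) // falls back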

// The name for UnsafeRow HashMap
private var hashMapTerm: String = _
private var sorterTerm: String = _

@@ -437,17 +447,18 @@ case class TungstenAggregate(
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")

// create AggregateHashMap
val isAggregateHashMapEnabled: Boolean = false
val isAggregateHashMapSupported: Boolean =
(groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName,
groupingKeySchema, bufferSchema)
if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
// Create a name for the iterator over the vectorized HashMap
val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
if (isVectorizedHashMapEnabled) {
ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm,
s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();")
ctx.addMutableState(
"java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
iterTermForVectorizedHashMap, "")
}

// create hashMap
@@ -465,11 +476,14 @@
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
$hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}

${if (isVectorizedHashMapEnabled) {
s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""}

$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
}
""")
@@ -484,13 +498,43 @@
// so `copyResult` should be reset to `false`.
ctx.copyResult = false

// Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow
def outputFromGeneratedMap: Option[String] = {
if (isVectorizedHashMapEnabled) {
val row = ctx.freshName("vectorizedHashMapRow")
ctx.currentVars = null
ctx.INPUT_ROW = row
var schema: StructType = groupingKeySchema
bufferSchema.foreach(i => schema = schema.add(i))
val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
Option(
s"""
| while ($iterTermForVectorizedHashMap.hasNext()) {
| $numOutput.add(1);
| org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
| (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
| $iterTermForVectorizedHashMap.next();
| ${generateRow.code}
| ${consume(ctx, Seq.empty, {generateRow.value})}
|
| if (shouldStop()) return;
| }
|
| $vectorizedHashMapTerm.close();
""".stripMargin)
} else None
}
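
The foreach/add loop above builds the output schema as the grouping key fields followed by the buffer fields; a sketch (hypothetical field names) showing it is equivalent to a direct concatenation:

import org.apache.spark.sql.types._

val groupingKeySchema = StructType(Seq(StructField("k", LongType)))
val bufferSchema = StructType(Seq(StructField("sum", LongType)))

var schema: StructType = groupingKeySchema
bufferSchema.foreach(f => schema = schema.add(f))

assert(schema.fields.toSeq == (groupingKeySchema ++ bufferSchema))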

s"""
if (!$initAgg) {
$initAgg = true;
$doAgg();
}

// output the result
${outputFromGeneratedMap.getOrElse("")}

while ($iterTerm.next()) {
$numOutput.add(1);
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
@@ -511,10 +555,13 @@

// create grouping key
ctx.currentVars = input
val keyCode = GenerateUnsafeProjection.createCode(
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val key = keyCode.value
val buffer = ctx.freshName("aggBuffer")
val vectorizedRowKeys = ctx.generateExpressions(
groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer")

// we only have DeclarativeAggregate functions here
val updateExpr = aggregateExpressions.flatMap { e =>
@@ -533,56 +580,124 @@

val inputAttr = aggregateBufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
// TODO: support subexpression elimination
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
val updates = evals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
}

val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.freshName("fallbackCounter")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
(s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "", "")
("true", "true", "", "")
}
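
A standalone sketch of the guard construction (hypothetical helper): with the test conf set to "2,4", the generated code probes the vectorized map only while the counter is below 2 and the BytesToBytesMap only while it is below 4; without the conf, both guards are the constant "true" and no counter is maintained.

def fallbackGuards(testFallbackStartsAt: Option[(Int, Int)],
    countTerm: String): (String, String, String, String) = testFallbackStartsAt match {
  case Some((vectorizedAt, unsafeAt)) =>
    (s"$countTerm < $vectorizedAt", s"$countTerm < $unsafeAt",
      s"$countTerm = 0;", s"$countTerm += 1;")
  case None => ("true", "true", "", "")
}

assert(fallbackGuards(Some((2, 4)), "cnt") ==
  ("cnt < 2", "cnt < 4", "cnt = 0;", "cnt += 1;"))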

// We first generate code to probe and update the vectorized hash map. If the probe is
// successful, the corresponding vectorized row buffer will hold the mutable row.
val findOrInsertInVectorizedHashMap: Option[String] = {
if (isVectorizedHashMapEnabled) {
Option(
s"""
|if ($checkFallbackForGeneratedHashMap) {
| ${vectorizedRowKeys.map(_.code).mkString("\n")}
| if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert(
| ${vectorizedRowKeys.map(_.value).mkString(", ")});
| }
|}
""".stripMargin)
} else {
None
}
}
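
The generated probe relies on a findOrInsert that returns a mutable row for the key, or null when it cannot (for instance, when the map is full); the generated guard also skips the probe entirely when any grouping key is null. Either way, vectorizedRowBuffer stays null and the unsafe-row path below takes over. A simplified stand-in under those assumptions (single long key, fixed capacity; the real map is generated per schema):

class SimpleVectorizedMap(capacity: Int = 1 << 16) {
  private val map = scala.collection.mutable.LinkedHashMap.empty[Long, Array[Long]]

  // Returns the mutable buffer for `key`, inserting a zero-initialized one if
  // absent; returns null when the map is full so the caller can fall back.
  def findOrInsert(key: Long): Array[Long] = {
    if (!map.contains(key) && map.size >= capacity) null
    else map.getOrElseUpdate(key, Array(0L))
  }
}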

val updateRowInVectorizedHashMap: Option[String] = {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = vectorizedRowBuffer
val vectorizedRowEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
Option(
s"""
|// evaluate aggregate function
|${evaluateVariables(vectorizedRowEvals)}
|// update vectorized row
|${updateVectorizedRow.mkString("\n").trim}
""".stripMargin)
} else None
}

// Next, we generate code to probe and update the unsafe row hash map.
val findOrInsertInUnsafeRowMap: String = {
s"""
| if ($vectorizedRowBuffer == null) {
| // generate grouping key
| ${unsafeRowKeyCode.code.trim}
| ${hashEval.code.trim}
| if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =
| $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
| }
| if ($unsafeRowBuffer == null) {
| if ($sorterTerm == null) {
| $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
| } else {
| $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
| }
| $resetCounter
| // the hash map has been spilled; it should have enough memory now,
| // so try to allocate the buffer again.
| $unsafeRowBuffer =
| $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
| if ($unsafeRowBuffer == null) {
| // failed to allocate the first page
| throw new OutOfMemoryError("Not enough memory for aggregation");
| }
| }
| }
""".stripMargin
}
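
The block above implements probe, spill, retry: a first failed lookup destructs the in-memory map into an external sorter (merging with any earlier spill) to free memory, and only a second failure is fatal. A control-flow sketch with hypothetical interfaces:

trait AggSorter { def merge(other: AggSorter): Unit }

trait AggHashMap {
  def getAggregationBufferFromUnsafeRow(key: AnyRef, hash: Int): AnyRef
  def destructAndCreateExternalSorter(): AggSorter
}

def probeWithSpill(map: AggHashMap, key: AnyRef, hash: Int,
    sorter: AggSorter): (AnyRef, AggSorter) = {
  val first = map.getAggregationBufferFromUnsafeRow(key, hash)
  if (first != null) return (first, sorter)
  // Out of memory: spill the map's contents, then try once more.
  val merged =
    if (sorter == null) map.destructAndCreateExternalSorter()
    else { sorter.merge(map.destructAndCreateExternalSorter()); sorter }
  val retry = map.getAggregationBufferFromUnsafeRow(key, hash)
  if (retry == null) throw new OutOfMemoryError("Not enough memory for aggregation")
  (retry, merged)
}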

val updateRowInUnsafeRowMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
val unsafeRowBufferEvals =
updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
s"""
|// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
|${updateUnsafeRowBuffer.mkString("\n").trim}
""".stripMargin
}


// We try hash-map-based in-memory aggregation first. If there is not enough memory (the
// hash map will return null for a new key), we spill the hash map to disk to free memory, then
// continue to do in-memory aggregation and spilling until all the rows have been processed.
// Finally, sort the spilled aggregate buffers by key, and merge buffers with the same key.
s"""
// generate grouping key
${keyCode.code.trim}
${hashEval.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
}
if ($buffer == null) {
if ($sorterTerm == null) {
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
} else {
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
}
$resetCoulter
// the hash map has been spilled; it should have enough memory now,
// so try to allocate the buffer again.
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("Not enough memory for aggregation");
}
}
UnsafeRow $unsafeRowBuffer = null;
org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null;

${findOrInsertInVectorizedHashMap.getOrElse("")}

$findOrInsertInUnsafeRowMap

$incCounter

// evaluate aggregate function
${evaluateVariables(evals)}
// update aggregate buffer
${updates.mkString("\n").trim}
if ($vectorizedRowBuffer != null) {
// update vectorized row
${updateRowInVectorizedHashMap.getOrElse("")}
} else {
// update unsafe row
$updateRowInUnsafeRowMap
}
"""
}
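
A hedged end-to-end example (driver-side API of the Spark 1.6/2.0 era this patch targets; the conf key is the one referenced above, the rest is standard DataFrame API). Forcing both fallbacks early exercises the vectorized map, the BytesToBytesMap, and the sort-based path in a single query:

sqlContext.setConf("spark.sql.TungstenAggregate.testFallbackStartsAt", "2,4")
// Long grouping key and long sum buffer, so the vectorized map is eligible
// (when the columnar aggregate map conf, not shown in this diff, is enabled).
val df = sqlContext.range(0, 100).selectExpr("id % 7 AS k", "id AS v")
df.groupBy("k").sum("v").show()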

@@ -85,7 +85,7 @@ class TungstenAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int],
testFallbackStartsAt: Option[(Int, Int)],
numOutputRows: LongSQLMetric,
dataSize: LongSQLMetric,
spillSize: LongSQLMetric)
@@ -171,7 +171,7 @@ class TungstenAggregationIterator(
// hashMap. If there is not enough memory, it will create multiple hash maps, spilling
// after each becomes full, then use sort to merge these spills, and finally do sort-
// based aggregation.
private def processInputs(fallbackStartsAt: Int): Unit = {
private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
if (groupingExpressions.isEmpty) {
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
// Note that it would be better to eliminate the hash map entirely in the future.
@@ -187,7 +187,7 @@
val newInput = inputIter.next()
val groupingKey = groupingProjection.apply(newInput)
var buffer: UnsafeRow = null
if (i < fallbackStartsAt) {
if (i < fallbackStartsAt._2) {
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
}
if (buffer == null) {
@@ -352,7 +352,7 @@
/**
* Start processing input rows.
*/
processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue))
processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue)))

// If we did not switch to sort-based aggregation in processInputs,
// we pre-load the first key-value pair from the map (to make hasNext idempotent).