diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 3d2443ca959a..56cf78d8b7fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -48,6 +48,12 @@ class RowBasedHashMapGenerator(
     val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
     val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
 
+    val numVarLenFields = groupingKeys.map(_.dataType).count {
+      case dt if UnsafeRow.isFixedLength(dt) => false
+      // TODO: consider large decimal and interval type
+      case _ => true
+    }
+
     s"""
        |  private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
        |  private int[] buckets;
@@ -60,6 +66,7 @@ class RowBasedHashMapGenerator(
        |  private long emptyVOff;
        |  private int emptyVLen;
        |  private boolean isBatchFull = false;
+       |  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
        |
        |
        |  public $generatedClassName(
@@ -75,6 +82,9 @@ class RowBasedHashMapGenerator(
        |    emptyVOff = Platform.BYTE_ARRAY_OFFSET;
        |    emptyVLen = emptyBuffer.length;
        |
+       |    agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
+       |      ${groupingKeySchema.length}, ${numVarLenFields * 32});
+       |
        |    buckets = new int[numBuckets];
        |    java.util.Arrays.fill(buckets, -1);
        |  }
@@ -112,12 +122,6 @@ class RowBasedHashMapGenerator(
    *
    */
   protected def generateFindOrInsert(): String = {
-    val numVarLenFields = groupingKeys.map(_.dataType).count {
-      case dt if UnsafeRow.isFixedLength(dt) => false
-      // TODO: consider large decimal and interval type
-      case _ => true
-    }
-
     val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
       key.dataType match {
         case t: DecimalType =>
@@ -130,6 +134,12 @@ class RowBasedHashMapGenerator(
       }
     }.mkString(";\n")
 
+    val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) {
+      ""
+    } else {
+      "agg_rowWriter.zeroOutNullBytes();"
+    }
+
     s"""
        |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
            groupingKeySignature}) {
@@ -140,12 +150,8 @@ class RowBasedHashMapGenerator(
        |    // Return bucket index if it's either an empty slot or already contains the key
        |    if (buckets[idx] == -1) {
        |      if (numRows < capacity && !isBatchFull) {
-       |        // creating the unsafe for new entry
-       |        org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
-       |          = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
-       |            ${groupingKeySchema.length}, ${numVarLenFields * 32});
-       |        agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
-       |        agg_rowWriter.zeroOutNullBytes();
+       |        agg_rowWriter.reset();
+       |        $resetNullBits
        |        ${createUnsafeRowForKey};
        |        org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
        |          = agg_rowWriter.getRow();
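
In short, the patch hoists the UnsafeRowWriter allocation out of the generated findOrInsert and into the generated class's constructor (sized once from numVarLenFields, which is computed at codegen time), and only emits zeroOutNullBytes() when at least one grouping key is nullable. Below is a minimal sketch of the same reuse pattern, relying only on the UnsafeRowWriter(numFields, initialBufferSize) constructor and the reset()/write()/getRow() calls that the generated code above already uses; the object name and the two non-nullable long key fields are illustrative assumptions, not part of the patch.

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

    object WriterReuseSketch {
      // Allocated once, like the agg_rowWriter field in the generated class.
      // Two fixed-length (long) key fields => numVarLenFields = 0, so the
      // initial variable-length buffer size (numVarLenFields * 32) is 0.
      private val keyWriter = new UnsafeRowWriter(2, 0)

      def buildKey(a: Long, b: Long): UnsafeRow = {
        keyWriter.reset() // rewind the write cursor; no per-record allocation
        // zeroOutNullBytes() is skipped: both fields are non-nullable, which
        // is exactly the case where the patch emits an empty $resetNullBits
        keyWriter.write(0, a)
        keyWriter.write(1, b)
        // The returned row is a view over the writer's buffer; it must be
        // copied or consumed before the next reset(), as the generated code
        // does when it appends the new key into the batch
        keyWriter.getRow()
      }
    }

The reuse is safe only because each UnsafeRow handed out by getRow() is consumed before the writer is reset again; in the generated findOrInsert, the key is appended into the RowBasedKeyValueBatch before the next call touches agg_rowWriter.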