Skip to content

Commit 917e7bb

Browse files
committed
Move fast hash map cleanup logic to HashAggregateExec
1 parent c9d09be commit 917e7bb

File tree

2 files changed

+56
-48
lines changed

2 files changed

+56
-48
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -725,40 +725,59 @@ case class HashAggregateExec(
725725

726726
val thisPlan = ctx.addReferenceObj("plan", this)
727727

728-
// Create a name for the iterator from the fast hash map, and the code to create fast hash map.
729-
val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) {
730-
// Generates the fast hash map class and creates the fast hash map term.
731-
val fastHashMapClassName = ctx.freshName("FastHashMap")
732-
if (isVectorizedHashMapEnabled) {
733-
val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
734-
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
735-
ctx.addInnerClass(generatedMap)
736-
737-
// Inline mutable state since not many aggregation operations in a task
738-
fastHashMapTerm = ctx.addMutableState(
739-
fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
740-
val iter = ctx.addMutableState(
741-
"java.util.Iterator<InternalRow>",
742-
"vectorizedFastHashMapIter",
743-
forceInline = true)
744-
val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
745-
(iter, create)
746-
} else {
747-
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
748-
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
749-
ctx.addInnerClass(generatedMap)
750-
751-
// Inline mutable state since not many aggregation operations in a task
752-
fastHashMapTerm = ctx.addMutableState(
753-
fastHashMapClassName, "fastHashMap", forceInline = true)
754-
val iter = ctx.addMutableState(
755-
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
756-
"fastHashMapIter", forceInline = true)
757-
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
758-
s"$thisPlan.getTaskContext(), $thisPlan.getEmptyAggregationBuffer());"
759-
(iter, create)
760-
}
761-
} else ("", "")
728+
// Create a name for the iterator from the fast hash map, the code to create the
729+
// fast hash map, and the code to register a hook that closes the fast hash map.
730+
val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) =
731+
if (isFastHashMapEnabled) {
732+
// Generates the fast hash map class and creates the fast hash map term.
733+
val fastHashMapClassName = ctx.freshName("FastHashMap")
734+
val (iter, create) = if (isVectorizedHashMapEnabled) {
735+
val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
736+
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
737+
ctx.addInnerClass(generatedMap)
738+
739+
// Inline mutable state since not many aggregation operations in a task
740+
fastHashMapTerm = ctx.addMutableState(
741+
fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
742+
val iter = ctx.addMutableState(
743+
"java.util.Iterator<InternalRow>",
744+
"vectorizedFastHashMapIter",
745+
forceInline = true)
746+
val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
747+
(iter, create)
748+
} else {
749+
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
750+
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
751+
ctx.addInnerClass(generatedMap)
752+
753+
// Inline mutable state since not many aggregation operations in a task
754+
fastHashMapTerm = ctx.addMutableState(
755+
fastHashMapClassName, "fastHashMap", forceInline = true)
756+
val iter = ctx.addMutableState(
757+
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
758+
"fastHashMapIter", forceInline = true)
759+
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
760+
s"$thisPlan.getTaskContext().taskMemoryManager(), " +
761+
s"$thisPlan.getEmptyAggregationBuffer());"
762+
(iter, create)
763+
}
764+
765+
// Generates the code to register a cleanup task with TaskContext to ensure that memory
766+
// is guaranteed to be freed at the end of the task. This is necessary to avoid memory
767+
// leaks when the downstream operator does not fully consume the aggregation map's
768+
// output (e.g. aggregate followed by limit).
769+
val hookToCloseFastHashMap =
770+
s"""
771+
|$thisPlan.getTaskContext().addTaskCompletionListener(
772+
| new org.apache.spark.util.TaskCompletionListener() {
773+
| @Override
774+
| public void onTaskCompletion(org.apache.spark.TaskContext context) {
775+
| $fastHashMapTerm.close();
776+
| }
777+
|});
778+
""".stripMargin
779+
(iter, create, hookToCloseFastHashMap)
780+
} else ("", "", "")
762781

763782
// Create a name for the iterator from the regular hash map.
764783
// Inline mutable state since not many aggregation operations in a task
@@ -877,6 +896,7 @@ case class HashAggregateExec(
877896
|if (!$initAgg) {
878897
| $initAgg = true;
879898
| $createFastHashMap
899+
| $addHookToCloseFastHashMap
880900
| $hashMapTerm = $thisPlan.createHashMap();
881901
| long $beforeAgg = System.nanoTime();
882902
| $doAggFuncName();

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ class RowBasedHashMapGenerator(
7070
|
7171
|
7272
| public $generatedClassName(
73-
| org.apache.spark.TaskContext taskContext,
73+
| org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
7474
| InternalRow emptyAggregationBuffer) {
7575
| batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
76-
| .allocate($keySchema, $valueSchema, taskContext.taskMemoryManager(), capacity);
76+
| .allocate($keySchema, $valueSchema, taskMemoryManager, capacity);
7777
|
7878
| final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema);
7979
| final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
@@ -87,18 +87,6 @@ class RowBasedHashMapGenerator(
8787
|
8888
| buckets = new int[numBuckets];
8989
| java.util.Arrays.fill(buckets, -1);
90-
|
91-
| // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be
92-
| // freed at the end of the task. This is necessary to avoid memory leaks in when the
93-
| // downstream operator does not fully consume the aggregation map's output
94-
| // (e.g. aggregate followed by limit).
95-
| taskContext.addTaskCompletionListener(
96-
| new org.apache.spark.util.TaskCompletionListener() {
97-
| @Override
98-
| public void onTaskCompletion(org.apache.spark.TaskContext context) {
99-
| close();
100-
| }
101-
| });
10290
| }
10391
""".stripMargin
10492
}

0 commit comments

Comments
 (0)