@@ -725,40 +725,59 @@ case class HashAggregateExec(
725725
726726 val thisPlan = ctx.addReferenceObj(" plan" , this )
727727
728- // Create a name for the iterator from the fast hash map, and the code to create fast hash map.
729- val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) {
730- // Generates the fast hash map class and creates the fast hash map term.
731- val fastHashMapClassName = ctx.freshName(" FastHashMap" )
732- if (isVectorizedHashMapEnabled) {
733- val generatedMap = new VectorizedHashMapGenerator (ctx, aggregateExpressions,
734- fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
735- ctx.addInnerClass(generatedMap)
736-
737- // Inline mutable state since not many aggregation operations in a task
738- fastHashMapTerm = ctx.addMutableState(
739- fastHashMapClassName, " vectorizedFastHashMap" , forceInline = true )
740- val iter = ctx.addMutableState(
741- " java.util.Iterator<InternalRow>" ,
742- " vectorizedFastHashMapIter" ,
743- forceInline = true )
744- val create = s " $fastHashMapTerm = new $fastHashMapClassName(); "
745- (iter, create)
746- } else {
747- val generatedMap = new RowBasedHashMapGenerator (ctx, aggregateExpressions,
748- fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
749- ctx.addInnerClass(generatedMap)
750-
751- // Inline mutable state since not many aggregation operations in a task
752- fastHashMapTerm = ctx.addMutableState(
753- fastHashMapClassName, " fastHashMap" , forceInline = true )
754- val iter = ctx.addMutableState(
755- " org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>" ,
756- " fastHashMapIter" , forceInline = true )
757- val create = s " $fastHashMapTerm = new $fastHashMapClassName( " +
758- s " $thisPlan.getTaskContext(), $thisPlan.getEmptyAggregationBuffer()); "
759- (iter, create)
760- }
761- } else (" " , " " )
728+ // Create a name for the iterator from the fast hash map, the code to create
729+ // and add hook to close fast hash map.
730+ val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) =
731+ if (isFastHashMapEnabled) {
732+ // Generates the fast hash map class and creates the fast hash map term.
733+ val fastHashMapClassName = ctx.freshName(" FastHashMap" )
734+ val (iter, create) = if (isVectorizedHashMapEnabled) {
735+ val generatedMap = new VectorizedHashMapGenerator (ctx, aggregateExpressions,
736+ fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
737+ ctx.addInnerClass(generatedMap)
738+
739+ // Inline mutable state since not many aggregation operations in a task
740+ fastHashMapTerm = ctx.addMutableState(
741+ fastHashMapClassName, " vectorizedFastHashMap" , forceInline = true )
742+ val iter = ctx.addMutableState(
743+ " java.util.Iterator<InternalRow>" ,
744+ " vectorizedFastHashMapIter" ,
745+ forceInline = true )
746+ val create = s " $fastHashMapTerm = new $fastHashMapClassName(); "
747+ (iter, create)
748+ } else {
749+ val generatedMap = new RowBasedHashMapGenerator (ctx, aggregateExpressions,
750+ fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
751+ ctx.addInnerClass(generatedMap)
752+
753+ // Inline mutable state since not many aggregation operations in a task
754+ fastHashMapTerm = ctx.addMutableState(
755+ fastHashMapClassName, " fastHashMap" , forceInline = true )
756+ val iter = ctx.addMutableState(
757+ " org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>" ,
758+ " fastHashMapIter" , forceInline = true )
759+ val create = s " $fastHashMapTerm = new $fastHashMapClassName( " +
760+ s " $thisPlan.getTaskContext().taskMemoryManager(), " +
761+ s " $thisPlan.getEmptyAggregationBuffer()); "
762+ (iter, create)
763+ }
764+
765+ // Generates the code to register a cleanup task with TaskContext to ensure that memory
766+ // is guaranteed to be freed at the end of the task. This is necessary to avoid memory
767+ // leaks in when the downstream operator does not fully consume the aggregation map's
768+ // output (e.g. aggregate followed by limit).
769+ val hookToCloseFastHashMap =
770+ s """
771+ | $thisPlan.getTaskContext().addTaskCompletionListener(
772+ | new org.apache.spark.util.TaskCompletionListener() {
773+ | @Override
774+ | public void onTaskCompletion(org.apache.spark.TaskContext context) {
775+ | $fastHashMapTerm.close();
776+ | }
777+ |});
778+ """ .stripMargin
779+ (iter, create, hookToCloseFastHashMap)
780+ } else (" " , " " , " " )
762781
763782 // Create a name for the iterator from the regular hash map.
764783 // Inline mutable state since not many aggregation operations in a task
@@ -877,6 +896,7 @@ case class HashAggregateExec(
877896 |if (! $initAgg) {
878897 | $initAgg = true;
879898 | $createFastHashMap
899+ | $addHookToCloseFastHashMap
880900 | $hashMapTerm = $thisPlan.createHashMap();
881901 | long $beforeAgg = System.nanoTime();
882902 | $doAggFuncName();
0 commit comments