Commit e0559f2

[SPARK-21743][SQL][FOLLOWUP] free aggregate map when task ends
## What changes were proposed in this pull request?

This is the first follow-up of #21573, which was only merged to 2.3.

This PR fixes the memory leak in another way: free the `UnsafeExternalMap` when the task ends.

All the data buffers in Spark SQL use `UnsafeExternalMap` and `UnsafeExternalSorter` under the hood (e.g. sort, aggregate, window, SMJ). `UnsafeExternalSorter` registers a task completion listener to free its resources; we should do the same for `UnsafeExternalMap`.

TODO in the next PR: do not consume all the inputs when there is a limit in whole-stage codegen.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #21738 from cloud-fan/limit.
1 parent 6fe3286 commit e0559f2
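
The change follows the pattern the commit message points to: `UnsafeExternalSorter` already frees its memory through a task completion listener, and this PR applies the same idea to the aggregation map. A minimal Scala sketch of that pattern is below; the class name and the body of `free()` are illustrative placeholders, not Spark's actual implementation.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical wrapper standing in for a map backed by task-managed memory
// (the role BytesToBytesMap plays inside UnsafeFixedWidthAggregationMap).
class ExampleUnsafeMap(taskContext: TaskContext) {

  // Release all memory held by this map back to the task memory manager.
  def free(): Unit = {
    // Placeholder: a real implementation would return its pages here.
  }

  // Tie the map's lifetime to the task itself: even if a downstream operator
  // (e.g. a limit) stops pulling rows before the map is fully consumed, the
  // memory is still released when the task completes.
  taskContext.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = free()
  })
}
```

The actual change does the equivalent in Java inside the `UnsafeFixedWidthAggregationMap` constructor, as shown in the first diff below.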

5 files changed: +28 -21 lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 12 additions & 5 deletions
@@ -20,8 +20,8 @@
 import java.io.IOException;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
 import org.apache.spark.internal.config.package$;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -82,27 +82,34 @@ public static boolean supportsAggregationBufferSchema(StructType schema)
    * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param taskContext the current task context.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    */
   public UnsafeFixedWidthAggregationMap(
       InternalRow emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      TaskMemoryManager taskMemoryManager,
+      TaskContext taskContext,
       int initialCapacity,
       long pageSizeBytes) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true);
+    this.map = new BytesToBytesMap(
+      taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
 
     // Initialize the buffer for aggregation value
     final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
     this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+    // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+    // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
+    // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+    taskContext.addTaskCompletionListener(context -> {
+      free();
+    });
   }
 
   /**

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 6 deletions
@@ -73,12 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             if limit < conf.topKSortFallbackThreshold =>
           TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
         case Limit(IntegerLiteral(limit), child) =>
-          // With whole stage codegen, Spark releases resources only when all the output data of the
-          // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little
-          // data from child plan and finishes the query without releasing resources. Here we wrap
-          // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and
-          // trigger the resource releasing work, after we consume `limit` rows.
-          CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+          CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
     case Limit(IntegerLiteral(limit), Sort(order, true, child))

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ case class HashAggregateExec(
         initialBuffer,
         bufferSchema,
         groupingKeySchema,
-        TaskContext.get().taskMemoryManager(),
+        TaskContext.get(),
         1024 * 16, // initial capacity
         TaskContext.get().taskMemoryManager().pageSizeBytes
       )

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ class TungstenAggregationIterator(
     initialAggregationBuffer,
     StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-    TaskContext.get().taskMemoryManager(),
+    TaskContext.get(),
     1024 * 16, // initial capacity
     TaskContext.get().taskMemoryManager().pageSizeBytes
   )

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 13 additions & 8 deletions
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import scala.util.{Random, Try}
 import scala.util.control.NonFatal
 
+import org.mockito.Mockito._
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite
   private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
+  private var taskContext: TaskContext = null
+
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
     def cleanup(): Unit = {
       if (taskMemoryManager != null) {
@@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite
     val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")
     memoryManager = new TestMemoryManager(conf)
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+    taskContext = mock(classOf[TaskContext])
+    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
@@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity,
       PAGE_SIZE_BYTES
     )
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
      1024, // initial capacity
      PAGE_SIZE_BYTES
    )
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      aggBufferSchema,
      groupKeySchema,
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      PAGE_SIZE_BYTES
    )
@@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      aggBufferSchema,
      groupKeySchema,
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      PAGE_SIZE_BYTES
    )
@@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      aggBufferSchema,
      groupKeySchema,
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      PAGE_SIZE_BYTES
    )
@@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      StructType(Nil),
      StructType(Nil),
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      PAGE_SIZE_BYTES
    )
@@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      aggBufferSchema,
      groupKeySchema,
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      pageSize
    )
@@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite
      emptyAggregationBuffer,
      aggBufferSchema,
      groupKeySchema,
-     taskMemoryManager,
+     taskContext,
      128, // initial capacity
      pageSize
    )
