Skip to content

Commit 70a39e4

Browse files
committed
Split MemoryManager into ExecutorMemoryManager and TaskMemoryManager:
- Implement memory leak detection, with exception vs. logging controlled by a configuration option.
1 parent 6e4b192 commit 70a39e4

File tree

21 files changed

+259
-60
lines changed

21 files changed

+259
-60
lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
4040
import org.apache.spark.serializer.Serializer
4141
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
4242
import org.apache.spark.storage._
43-
import org.apache.spark.unsafe.memory.{MemoryManager => UnsafeMemoryManager, MemoryAllocator}
43+
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
4444
import org.apache.spark.util.{RpcUtils, Utils}
4545

4646
/**
@@ -70,7 +70,7 @@ class SparkEnv (
7070
val sparkFilesDir: String,
7171
val metricsSystem: MetricsSystem,
7272
val shuffleMemoryManager: ShuffleMemoryManager,
73-
val unsafeMemoryManager: UnsafeMemoryManager,
73+
val executorMemoryManager: ExecutorMemoryManager,
7474
val outputCommitCoordinator: OutputCommitCoordinator,
7575
val conf: SparkConf) extends Logging {
7676

@@ -384,13 +384,13 @@ object SparkEnv extends Logging {
384384
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
385385
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
386386

387-
val unsafeMemoryManager: UnsafeMemoryManager = {
387+
val executorMemoryManager: ExecutorMemoryManager = {
388388
val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
389389
MemoryAllocator.UNSAFE
390390
} else {
391391
MemoryAllocator.HEAP
392392
}
393-
new UnsafeMemoryManager(allocator)
393+
new ExecutorMemoryManager(allocator)
394394
}
395395

396396
val envInstance = new SparkEnv(
@@ -409,7 +409,7 @@ object SparkEnv extends Logging {
409409
sparkFilesDir,
410410
metricsSystem,
411411
shuffleMemoryManager,
412-
unsafeMemoryManager,
412+
executorMemoryManager,
413413
outputCommitCoordinator,
414414
conf)
415415

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io.Serializable
2121

2222
import org.apache.spark.annotation.DeveloperApi
2323
import org.apache.spark.executor.TaskMetrics
24+
import org.apache.spark.unsafe.memory.TaskMemoryManager
2425
import org.apache.spark.util.TaskCompletionListener
2526

2627

@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
133134
/** ::DeveloperApi:: */
134135
@DeveloperApi
135136
def taskMetrics(): TaskMetrics
137+
138+
/**
139+
* Returns the manager for this task's managed memory.
140+
*/
141+
private[spark] def taskMemoryManager(): TaskMemoryManager
136142
}

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark
1919

2020
import org.apache.spark.executor.TaskMetrics
21+
import org.apache.spark.unsafe.memory.TaskMemoryManager
2122
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
2223

2324
import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
2728
val partitionId: Int,
2829
override val taskAttemptId: Long,
2930
override val attemptNumber: Int,
31+
override val taskMemoryManager: TaskMemoryManager,
3032
val runningLocally: Boolean = false,
3133
val taskMetrics: TaskMetrics = TaskMetrics.empty)
3234
extends TaskContext

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
3232
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
3333
import org.apache.spark.shuffle.FetchFailedException
3434
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
35+
import org.apache.spark.unsafe.memory.TaskMemoryManager
3536
import org.apache.spark.util._
3637

3738
/**
@@ -179,6 +180,7 @@ private[spark] class Executor(
179180
}
180181

181182
override def run(): Unit = {
183+
val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
182184
val deserializeStartTime = System.currentTimeMillis()
183185
Thread.currentThread.setContextClassLoader(replClassLoader)
184186
val ser = env.closureSerializer.newInstance()
@@ -191,6 +193,7 @@ private[spark] class Executor(
191193
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
192194
updateDependencies(taskFiles, taskJars)
193195
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
196+
task.setTaskMemoryManager(taskMemoryManager)
194197

195198
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
196199
// continue executing the task.
@@ -207,7 +210,23 @@ private[spark] class Executor(
207210

208211
// Run the actual task and measure its runtime.
209212
taskStart = System.currentTimeMillis()
210-
val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
213+
var succeeded: Boolean = false
214+
val value = try {
215+
val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
216+
succeeded = true
217+
value
218+
} finally {
219+
// Release managed memory used by this task
220+
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
221+
if (succeeded && freedMemory > 0) {
222+
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
223+
if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
224+
throw new SparkException(errMsg)
225+
} else {
226+
logError(errMsg)
227+
}
228+
}
229+
}
211230
val taskFinish = System.currentTimeMillis()
212231

213232
// If the task has been killed, let's fail it.

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
3434
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
3535
import org.apache.spark.rdd.RDD
3636
import org.apache.spark.storage._
37+
import org.apache.spark.unsafe.memory.TaskMemoryManager
3738
import org.apache.spark.util._
3839
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
3940

@@ -643,15 +644,32 @@ class DAGScheduler(
643644
try {
644645
val rdd = job.finalStage.rdd
645646
val split = rdd.partitions(job.partitions(0))
646-
val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
647-
attemptNumber = 0, runningLocally = true)
647+
val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
648+
val taskContext =
649+
new TaskContextImpl(
650+
job.finalStage.id,
651+
job.partitions(0),
652+
taskAttemptId = 0,
653+
attemptNumber = 0,
654+
taskMemoryManager = taskMemoryManager,
655+
runningLocally = true)
648656
TaskContext.setTaskContext(taskContext)
657+
var succeeded: Boolean = false
649658
try {
650659
val result = job.func(taskContext, rdd.iterator(split, taskContext))
660+
succeeded = true
651661
job.listener.taskSucceeded(0, result)
652662
} finally {
653663
taskContext.markTaskCompleted()
654664
TaskContext.unset()
665+
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
666+
if (succeeded && freedMemory > 0) {
667+
if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
668+
throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
669+
} else {
670+
logError(s"Managed memory leak detected; size = $freedMemory bytes")
671+
}
672+
}
655673
}
656674
} catch {
657675
case e: Exception =>

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
2525
import org.apache.spark.{TaskContextImpl, TaskContext}
2626
import org.apache.spark.executor.TaskMetrics
2727
import org.apache.spark.serializer.SerializerInstance
28+
import org.apache.spark.unsafe.memory.TaskMemoryManager
2829
import org.apache.spark.util.ByteBufferInputStream
2930
import org.apache.spark.util.Utils
3031

@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
5253
* @return the result of the task
5354
*/
5455
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
55-
context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
56-
taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
56+
context = new TaskContextImpl(
57+
stageId = stageId,
58+
partitionId = partitionId,
59+
taskAttemptId = taskAttemptId,
60+
attemptNumber = attemptNumber,
61+
taskMemoryManager = taskMemoryManager,
62+
runningLocally = false)
5763
TaskContext.setTaskContext(context)
5864
context.taskMetrics.setHostname(Utils.localHostName())
5965
taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
6874
}
6975
}
7076

77+
private var taskMemoryManager: TaskMemoryManager = _
78+
79+
def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
80+
this.taskMemoryManager = taskMemoryManager
81+
}
82+
7183
def runTask(context: TaskContext): T
7284

7385
def preferredLocations: Seq[TaskLocation] = Nil

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ public void persist() {
10091009
@Test
10101010
public void iterator() {
10111011
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
1012-
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
1012+
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
10131013
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
10141014
}
10151015

core/src/test/scala/org/apache/spark/CacheManagerSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
6565
// in blockManager.put is a losing battle. You have been warned.
6666
blockManager = sc.env.blockManager
6767
cacheManager = sc.env.cacheManager
68-
val context = new TaskContextImpl(0, 0, 0, 0)
68+
val context = new TaskContextImpl(0, 0, 0, 0, null)
6969
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
7070
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
7171
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
7777
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
7878
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
7979

80-
val context = new TaskContextImpl(0, 0, 0, 0)
80+
val context = new TaskContextImpl(0, 0, 0, 0, null)
8181
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
8282
assert(value.toList === List(5, 6, 7))
8383
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
8686
// Local computation should not persist the resulting value, so don't expect a put().
8787
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
8888

89-
val context = new TaskContextImpl(0, 0, 0, 0, true)
89+
val context = new TaskContextImpl(0, 0, 0, 0, null, true)
9090
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
9191
assert(value.toList === List(1, 2, 3, 4))
9292
}
9393

9494
test("verify task metrics updated correctly") {
9595
cacheManager = sc.env.cacheManager
96-
val context = new TaskContextImpl(0, 0, 0, 0)
96+
val context = new TaskContextImpl(0, 0, 0, 0, null)
9797
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
9898
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
9999
}

core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
176176
}
177177
val hadoopPart1 = generateFakeHadoopPartition()
178178
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
179-
val tContext = new TaskContextImpl(0, 0, 0, 0)
179+
val tContext = new TaskContextImpl(0, 0, 0, 0, null)
180180
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
181181
val arr = rddIter.toArray
182182
assert(arr(0) == "/some/path")

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
5151
}
5252

5353
test("all TaskCompletionListeners should be called even if some fail") {
54-
val context = new TaskContextImpl(0, 0, 0, 0)
54+
val context = new TaskContextImpl(0, 0, 0, 0, null)
5555
val listener = mock(classOf[TaskCompletionListener])
5656
context.addTaskCompletionListener(_ => throw new Exception("blah"))
5757
context.addTaskCompletionListener(listener)

0 commit comments

Comments
 (0)