
Commit f49284b

JoshRosen authored and rxin committed
[SPARK-7076][SPARK-7077][SPARK-7080][SQL] Use managed memory for aggregations
This patch adds managed-memory-based aggregation to Spark SQL / DataFrames. Instead of working with Java objects, this new aggregation path uses `sun.misc.Unsafe` to manipulate raw memory. This reduces the memory footprint of aggregations, resulting in fewer spills, OutOfMemoryErrors, and garbage collection pauses, and it allows for higher memory utilization. It can also result in better cache locality, since objects are stored closer together in memory. This feature can be enabled by setting `spark.sql.unsafe.enabled=true`. For now, it is only supported when codegen is enabled, and it only supports aggregations in which the grouping columns are primitive numeric types or strings and the aggregated values are numeric.

### Managing memory with sun.misc.Unsafe

This patch supports both on- and off-heap managed memory.

- In on-heap mode, memory addresses are identified by the combination of a base Object and an offset within that object.
- In off-heap mode, memory is addressed directly with 64-bit long addresses.

To support both modes, functions that manipulate memory accept both `baseObject` and `baseOffset` fields. In off-heap mode, we simply pass `null` as `baseObject`. We allocate memory in large chunks, so memory fragmentation and allocation speed are not significant bottlenecks. By default, we use on-heap mode. To enable off-heap mode, set `spark.unsafe.offHeap=true`.

To track allocated memory, this patch extends `SparkEnv` with an `ExecutorMemoryManager` and supplies each `TaskContext` with a `TaskMemoryManager`. These classes work together to track allocations and detect memory leaks.

### Compact tuple format

This patch introduces `UnsafeRow`, a compact row layout. In this format, each tuple has three parts: a null bit set, fixed-length values, and variable-length values:

![image](https://cloud.githubusercontent.com/assets/50748/7328538/2fdb65ce-ea8b-11e4-9743-6c0f02bb7d1f.png)

- Rows are always 8-byte word aligned, so their sizes are always a multiple of 8 bytes.
- The bit set is used for null tracking:
  - Position _i_ is set if and only if field _i_ is null.
  - The bit set is aligned to an 8-byte word boundary.
- Every field appears as an 8-byte word in the fixed-length values part:
  - If a field is null, we zero out its word.
  - If a field is variable-length, the word stores a relative offset (w.r.t. the base of the tuple) that points to the beginning of the field's data in the variable-length part.
- Each variable-length data type can have its own encoding:
  - For strings, the first word stores the length of the string and is followed by UTF-8 encoded bytes. If necessary, the end of the string is padded with empty bytes in order to ensure word alignment.

For example, a tuple that consists of 3 fields of type (int, string, string), with values (null, "data", "bricks"), would look like this:

![image](https://cloud.githubusercontent.com/assets/50748/7328526/1e21959c-ea8b-11e4-9a28-a4350fe4a7b5.png)

This format allows us to compare tuples for equality by directly comparing their raw bytes, and it also enables fast hashing of tuples.
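To make the layout rules concrete, here is a small illustrative sketch (not the `UnsafeRow` code added by this patch) that lays out the (null, "data", "bricks") example into a plain byte array; all names and helpers in it are hypothetical:

```scala
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets

object UnsafeRowLayoutSketch {
  private def roundUpTo8(n: Int): Int = ((n + 7) / 8) * 8

  /** Encodes (null, "data", "bricks") as: null bit set | fixed-length words | variable-length data. */
  def encodeExample(): Array[Byte] = {
    val numFields = 3
    val bitSetBytes = roundUpTo8((numFields + 7) / 8)   // one 8-byte word covers up to 64 fields
    val fixedBytes = numFields * 8                      // one 8-byte word per field
    val strings = Seq("data", "bricks").map(_.getBytes(StandardCharsets.UTF_8))
    // Each variable-length region: an 8-byte length word, then UTF-8 bytes padded to a word boundary.
    val varBytes = strings.map(s => 8 + roundUpTo8(s.length)).sum
    val buf = ByteBuffer.allocate(bitSetBytes + fixedBytes + varBytes).order(ByteOrder.LITTLE_ENDIAN)

    buf.putLong(0, 1L << 0)                             // bit 0 set: field 0 (the int) is null
    // Field 0 is null, so its fixed-length word stays zeroed out.
    var varOffset = bitSetBytes + fixedBytes            // offsets are relative to the base of the tuple
    strings.zipWithIndex.foreach { case (bytes, i) =>
      buf.putLong(bitSetBytes + (i + 1) * 8, varOffset) // fixed word = relative offset to the string data
      buf.putLong(varOffset, bytes.length)              // first word of the region = string length
      var j = 0
      while (j < bytes.length) { buf.put(varOffset + 8 + j, bytes(j)); j += 1 }
      varOffset += 8 + roundUpTo8(bytes.length)
    }
    buf.array()
  }

  def main(args: Array[String]): Unit = {
    val row = encodeExample()
    println(s"row size = ${row.length} bytes")          // 8 + 24 + 16 + 16 = 64, a multiple of 8
  }
}
```

The real implementation writes through `sun.misc.Unsafe` against a (`baseObject`, `baseOffset`) pair rather than a `ByteBuffer`, which is what lets the same code serve both on-heap and off-heap memory.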
### Hash map for performing aggregations

This patch introduces `UnsafeFixedWidthAggregationMap`, a hash map for performing aggregations where the aggregation result columns are fixed-width. This map's keys and values are `Row` objects. `UnsafeFixedWidthAggregationMap` is implemented on top of `BytesToBytesMap`, an append-only map that supports byte-array keys and values; `BytesToBytesMap` stores pointers to key and value tuples.

For each record with a new key, we copy the key, create the aggregation value buffer for that key, and put both into a buffer; the hash table then simply stores pointers to the key and value. For each record with an existing key, we simply run the aggregation function to update the values in place.

This map is implemented using open hashing with triangular sequence probing. Each entry stores two words in a long array: the first word stores the address of the key, and the second word stores the relative offset from the key tuple to the value tuple, as well as the key's 32-bit hashcode. By storing the full hashcode, we reduce the number of equality checks needed to handle position collisions (since the chance of a hashcode collision is much lower than that of a position collision).

`UnsafeFixedWidthAggregationMap` allows regular Spark SQL `Row` objects to be used when probing the map. Internally, it encodes these rows into the `UnsafeRow` format using `UnsafeRowConverter`. This conversion has a small overhead that can be eliminated in the future once we use UnsafeRows in other operators.
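As a rough illustration of the probing scheme only (not the actual `BytesToBytesMap` code), the sketch below keeps two words per slot in a long array and advances by an increasing step on each collision; the real map also packs the key-to-value offset into the second word, byte-compares keys on a hashcode match, and grows when it fills up:

```scala
// Simplified, hypothetical sketch of open addressing with triangular sequence probing.
class TriangularProbingSketch(capacity: Int) {
  require((capacity & (capacity - 1)) == 0, "capacity must be a power of two")
  private val longArray = new Array[Long](capacity * 2)  // two words per slot
  private val occupied = new Array[Boolean](capacity)

  /** Probes for `hash`: returns the first slot that is empty or whose stored hashcode matches.
    * (The real map then byte-compares the key stored at that slot's key address.) */
  def lookup(hash: Int): Int = {
    var pos = hash & (capacity - 1)
    var step = 1
    while (occupied(pos) && (longArray(pos * 2 + 1) & 0xFFFFFFFFL) != (hash & 0xFFFFFFFFL)) {
      pos = (pos + step) & (capacity - 1)                // triangular probing: step sizes 1, 2, 3, ...
      step += 1
    }
    pos
  }

  /** Records a key; assumes the table never fills up (the real map grows itself). */
  def put(keyAddress: Long, hash: Int): Unit = {
    val pos = lookup(hash)
    longArray(pos * 2) = keyAddress                      // word 1: address of the key tuple
    longArray(pos * 2 + 1) = hash & 0xFFFFFFFFL          // word 2: 32-bit hashcode
    occupied(pos) = true
  }
}
```

With a power-of-two capacity, the triangular step sequence eventually visits every slot, which is why an empty slot is always found as long as the table is not full.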
Author: Josh Rosen <[email protected]>

Closes apache#5725 from JoshRosen/unsafe and squashes the following commits:

eeee512 [Josh Rosen] Add converters for Null, Boolean, Byte, and Short columns.
81f34f8 [Josh Rosen] Follow 'place children last' convention for GeneratedAggregate
1bc36cc [Josh Rosen] Refactor UnsafeRowConverter to avoid unnecessary boxing.
017b2dc [Josh Rosen] Remove BytesToBytesMap.finalize()
50e9671 [Josh Rosen] Throw memory leak warning even in case of error; add warning about code duplication
70a39e4 [Josh Rosen] Split MemoryManager into ExecutorMemoryManager and TaskMemoryManager:
6e4b192 [Josh Rosen] Remove an unused method from ByteArrayMethods.
de5e001 [Josh Rosen] Fix debug vs. trace in logging message.
a19e066 [Josh Rosen] Rename unsafe Java test suites to match Scala test naming convention.
78a5b84 [Josh Rosen] Add logging to MemoryManager
ce3c565 [Josh Rosen] More comments, formatting, and code cleanup.
529e571 [Josh Rosen] Measure timeSpentResizing in nanoseconds instead of milliseconds.
3ca84b2 [Josh Rosen] Only zero the used portion of groupingKeyConversionScratchSpace
162caf7 [Josh Rosen] Fix test compilation
b45f070 [Josh Rosen] Don't redundantly store the offset from key to value, since we can compute this from the key size.
a8e4a3f [Josh Rosen] Introduce MemoryManager interface; add to SparkEnv.
0925847 [Josh Rosen] Disable MiMa checks for new unsafe module
cde4132 [Josh Rosen] Add missing pom.xml
9c19fc0 [Josh Rosen] Add configuration options for heap vs. offheap
6ffdaa1 [Josh Rosen] Null handling improvements in UnsafeRow.
31eaabc [Josh Rosen] Lots of TODO and doc cleanup.
a95291e [Josh Rosen] Cleanups to string handling code
afe8dca [Josh Rosen] Some Javadoc cleanup
f3dcbfe [Josh Rosen] More mod replacement
854201a [Josh Rosen] Import and comment cleanup
06e929d [Josh Rosen] More warning cleanup
ef6b3d3 [Josh Rosen] Fix a bunch of FindBugs and IntelliJ inspections
29a7575 [Josh Rosen] Remove debug logging
49aed30 [Josh Rosen] More long -> int conversion.
b26f1d3 [Josh Rosen] Fix bug in murmur hash implementation.
765243d [Josh Rosen] Enable optional performance metrics for hash map.
23a440a [Josh Rosen] Bump up default hash map size
628f936 [Josh Rosen] Use ints intead of longs for indexing.
92d5a06 [Josh Rosen] Address a number of minor code review comments.
1f4b716 [Josh Rosen] Merge Unsafe code into the regular GeneratedAggregate, guarded by a configuration flag; integrate planner support and re-enable all tests.
d85eeff [Josh Rosen] Add basic sanity test for UnsafeFixedWidthAggregationMap
bade966 [Josh Rosen] Comment update (bumping to refresh GitHub cache...)
b3eaccd [Josh Rosen] Extract aggregation map into its own class.
d2bb986 [Josh Rosen] Update to implement new Row methods added upstream
58ac393 [Josh Rosen] Use UNSAFE allocator in GeneratedAggregate (TODO: make this configurable)
7df6008 [Josh Rosen] Optimizations related to zeroing out memory:
c1b3813 [Josh Rosen] Fix bug in UnsafeMemoryAllocator.free():
738fa33 [Josh Rosen] Add feature flag to guard UnsafeGeneratedAggregate
c55bf66 [Josh Rosen] Free buffer once iterator has been fully consumed.
62ab054 [Josh Rosen] Optimize for fact that get() is only called on String columns.
c7f0b56 [Josh Rosen] Reuse UnsafeRow pointer in UnsafeRowConverter
ae39694 [Josh Rosen] Add finalizer as "cleanup method of last resort"
c754ae1 [Josh Rosen] Now that the store*() contract has been stregthened, we can remove an extra lookup
f764d13 [Josh Rosen] Simplify address + length calculation in Location.
079f1bf [Josh Rosen] Some clarification of the BytesToBytesMap.lookup() / set() contract.
1a483c5 [Josh Rosen] First version that passes some aggregation tests:
fc4c3a8 [Josh Rosen] Sketch how the converters will be used in UnsafeGeneratedAggregate
53ba9b7 [Josh Rosen] Start prototyping Java Row -> UnsafeRow converters
1ff814d [Josh Rosen] Add reminder to free memory on iterator completion
8a8f9df [Josh Rosen] Add skeleton for GeneratedAggregate integration.
5d55cef [Josh Rosen] Add skeleton for Row implementation.
f03e9c1 [Josh Rosen] Play around with Unsafe implementations of more string methods.
ab68e08 [Josh Rosen] Begin merging the UTF8String implementations.
480a74a [Josh Rosen] Initial import of code from Databricks unsafe utils repo.
1 parent 1fd6ed9 commit f49284b

File tree

47 files changed: +3675 −18 lines


core/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -95,6 +95,11 @@
       <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>net.java.dev.jets3t</groupId>
       <artifactId>jets3t</artifactId>

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 12 additions & 0 deletions
@@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
 import org.apache.spark.util.{RpcUtils, Utils}

 /**
@@ -69,6 +70,7 @@ class SparkEnv (
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
     val shuffleMemoryManager: ShuffleMemoryManager,
+    val executorMemoryManager: ExecutorMemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {

@@ -382,6 +384,15 @@ object SparkEnv extends Logging {
       new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
     outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)

+    val executorMemoryManager: ExecutorMemoryManager = {
+      val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
+        MemoryAllocator.UNSAFE
+      } else {
+        MemoryAllocator.HEAP
+      }
+      new ExecutorMemoryManager(allocator)
+    }
+
     val envInstance = new SparkEnv(
       executorId,
       rpcEnv,
@@ -398,6 +409,7 @@ object SparkEnv extends Logging {
       sparkFilesDir,
       metricsSystem,
       shuffleMemoryManager,
+      executorMemoryManager,
       outputCommitCoordinator,
       conf)
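For context, the configuration keys referenced by the commit message and by this diff can be set on a `SparkConf` before the application starts. The sketch below is illustrative usage only, not part of the patch; the application name and structure are hypothetical:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object UnsafeAggregationConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("unsafe-aggregation-sketch")            // hypothetical application name
      .set("spark.sql.unsafe.enabled", "true")            // opt in to the managed-memory aggregation path
      .set("spark.unsafe.offHeap", "true")                // select MemoryAllocator.UNSAFE instead of HEAP
      .set("spark.unsafe.exceptionOnMemoryLeak", "true")  // fail tasks instead of only logging leaked memory
    val sc = new SparkContext(conf)
    // ... run Spark SQL / DataFrame aggregations here ...
    sc.stop()
  }
}
```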

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,7 @@ import java.io.Serializable

 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.TaskCompletionListener


@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
   /** ::DeveloperApi:: */
   @DeveloperApi
   def taskMetrics(): TaskMetrics
+
+  /**
+   * Returns the manager for this task's managed memory.
+   */
+  private[spark] def taskMemoryManager(): TaskMemoryManager
 }

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark

 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

 import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
     val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
+    override val taskMemoryManager: TaskMemoryManager,
     val runningLocally: Boolean = false,
     val taskMetrics: TaskMetrics = TaskMetrics.empty)
   extends TaskContext

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 18 additions & 1 deletion
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._

 /**
@@ -178,6 +179,7 @@ private[spark] class Executor(
     }

     override def run(): Unit = {
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()
@@ -190,6 +192,7 @@
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task.setTaskMemoryManager(taskMemoryManager)

         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
         // continue executing the task.
@@ -206,7 +209,21 @@

         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
-        val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+        val value = try {
+          task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+        } finally {
+          // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
+          // when changing this, make sure to update both copies.
+          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+          if (freedMemory > 0) {
+            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+              throw new SparkException(errMsg)
+            } else {
+              logError(errMsg)
+            }
+          }
+        }
         val taskFinish = System.currentTimeMillis()

         // If the task has been killed, let's fail it.

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 20 additions & 2 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat

@@ -643,15 +644,32 @@ class DAGScheduler(
     try {
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
-        attemptNumber = 0, runningLocally = true)
+      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+      val taskContext =
+        new TaskContextImpl(
+          job.finalStage.id,
+          job.partitions(0),
+          taskAttemptId = 0,
+          attemptNumber = 0,
+          taskMemoryManager = taskMemoryManager,
+          runningLocally = true)
       TaskContext.setTaskContext(taskContext)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
         TaskContext.unset()
+        // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
+        // make sure to update both copies.
+        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+        if (freedMemory > 0) {
+          if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+            throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+          } else {
+            logError(s"Managed memory leak detected; size = $freedMemory bytes")
+          }
+        }
       }
     } catch {
       case e: Exception =>

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 14 additions & 2 deletions
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.{TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.ByteBufferInputStream
 import org.apache.spark.util.Utils

@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
   * @return the result of the task
   */
  final def run(taskAttemptId: Long, attemptNumber: Int): T = {
-    context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
-      taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+    context = new TaskContextImpl(
+      stageId = stageId,
+      partitionId = partitionId,
+      taskAttemptId = taskAttemptId,
+      attemptNumber = attemptNumber,
+      taskMemoryManager = taskMemoryManager,
+      runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
     taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     }
   }

+  private var taskMemoryManager: TaskMemoryManager = _
+
+  def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+    this.taskMemoryManager = taskMemoryManager
+  }
+
  def runTask(context: TaskContext): T

  def preferredLocations: Seq[TaskLocation] = Nil

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 1 addition & 1 deletion
@@ -1009,7 +1009,7 @@ public void persist() {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }

core/src/test/scala/org/apache/spark/CacheManagerSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     // in blockManager.put is a losing battle. You have been warned.
     blockManager = sc.env.blockManager
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
     assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))

-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
     // Local computation should not persist the resulting value, so don't expect a put().
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)

-    val context = new TaskContextImpl(0, 0, 0, 0, true)
+    val context = new TaskContextImpl(0, 0, 0, 0, null, true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }

   test("verify task metrics updated correctly") {
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0)
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
     cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
     assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
   }

core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
     }
     val hadoopPart1 = generateFakeHadoopPartition()
     val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-    val tContext = new TaskContextImpl(0, 0, 0, 0)
+    val tContext = new TaskContextImpl(0, 0, 0, 0, null)
     val rddIter = pipedRdd.compute(hadoopPart1, tContext)
     val arr = rddIter.toArray
     assert(arr(0) == "/some/path")
