Skip to content

Commit a8e4a3f

Browse files
committed
Introduce MemoryManager interface; add to SparkEnv.
The configuration of HEAP vs UNSAFE is now done at the Spark core level. The translation of encoded 64-bit addresses into base object + offset pairs is now handled by MemoryManager, allowing these pointers to be safely passed between operators that exchange data pages.
1 parent 0925847 commit a8e4a3f

File tree

11 files changed

+265
-114
lines changed

11 files changed

+265
-114
lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
4040
import org.apache.spark.serializer.Serializer
4141
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
4242
import org.apache.spark.storage._
43+
import org.apache.spark.unsafe.memory.{MemoryManager => UnsafeMemoryManager, MemoryAllocator}
4344
import org.apache.spark.util.{RpcUtils, Utils}
4445

4546
/**
@@ -69,6 +70,7 @@ class SparkEnv (
6970
val sparkFilesDir: String,
7071
val metricsSystem: MetricsSystem,
7172
val shuffleMemoryManager: ShuffleMemoryManager,
73+
val unsafeMemoryManager: UnsafeMemoryManager,
7274
val outputCommitCoordinator: OutputCommitCoordinator,
7375
val conf: SparkConf) extends Logging {
7476

@@ -382,6 +384,15 @@ object SparkEnv extends Logging {
382384
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
383385
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
384386

387+
val unsafeMemoryManager: UnsafeMemoryManager = {
388+
val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
389+
MemoryAllocator.UNSAFE
390+
} else {
391+
MemoryAllocator.HEAP
392+
}
393+
new UnsafeMemoryManager(allocator)
394+
}
395+
385396
val envInstance = new SparkEnv(
386397
executorId,
387398
rpcEnv,
@@ -398,6 +409,7 @@ object SparkEnv extends Logging {
398409
sparkFilesDir,
399410
metricsSystem,
400411
shuffleMemoryManager,
412+
unsafeMemoryManager,
401413
outputCommitCoordinator,
402414
conf)
403415

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import org.apache.spark.sql.types.StructType;
2626
import org.apache.spark.unsafe.PlatformDependent;
2727
import org.apache.spark.unsafe.map.BytesToBytesMap;
28-
import org.apache.spark.unsafe.memory.MemoryAllocator;
2928
import org.apache.spark.unsafe.memory.MemoryLocation;
29+
import org.apache.spark.unsafe.memory.MemoryManager;
3030

3131
/**
3232
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -102,23 +102,23 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
102102
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
103103
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
104104
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
105-
* @param allocator the memory allocator used to allocate our Unsafe memory structures.
105+
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
106106
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
107107
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
108108
*/
109109
public UnsafeFixedWidthAggregationMap(
110110
Row emptyAggregationBuffer,
111111
StructType aggregationBufferSchema,
112112
StructType groupingKeySchema,
113-
MemoryAllocator allocator,
113+
MemoryManager memoryManager,
114114
int initialCapacity,
115115
boolean enablePerfMetrics) {
116116
this.emptyAggregationBuffer =
117117
convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
118118
this.aggregationBufferSchema = aggregationBufferSchema;
119119
this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
120120
this.groupingKeySchema = groupingKeySchema;
121-
this.map = new BytesToBytesMap(allocator, initialCapacity, enablePerfMetrics);
121+
this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
122122
this.enablePerfMetrics = enablePerfMetrics;
123123
}
124124

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,32 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.unsafe.memory.MemoryAllocator
21-
import org.scalatest.{FunSuite, Matchers}
20+
import org.apache.spark.unsafe.memory.{MemoryManager, MemoryAllocator}
21+
import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
2222

2323
import org.apache.spark.sql.types._
2424

25-
class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers {
25+
class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach {
2626

2727
import UnsafeFixedWidthAggregationMap._
2828

2929
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
3030
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
3131
private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
3232

33+
private var memoryManager: MemoryManager = null
34+
35+
override def beforeEach(): Unit = {
36+
memoryManager = new MemoryManager(true)
37+
}
38+
39+
override def afterEach(): Unit = {
40+
if (memoryManager != null) {
41+
memoryManager.cleanUpAllPages()
42+
memoryManager = null
43+
}
44+
}
45+
3346
test("supported schemas") {
3447
assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
3548
assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
@@ -45,7 +58,7 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers {
4558
emptyAggregationBuffer,
4659
aggBufferSchema,
4760
groupKeySchema,
48-
MemoryAllocator.HEAP,
61+
memoryManager,
4962
1024,
5063
false
5164
)
@@ -58,7 +71,7 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers {
5871
emptyAggregationBuffer,
5972
aggBufferSchema,
6073
groupKeySchema,
61-
MemoryAllocator.HEAP,
74+
memoryManager,
6275
1024,
6376
false
6477
)

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ private[spark] object SQLConf {
3131
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
3232
val CODEGEN_ENABLED = "spark.sql.codegen"
3333
val UNSAFE_ENABLED = "spark.sql.unsafe.enabled"
34-
val UNSAFE_USE_OFF_HEAP = "spark.sql.unsafe.offHeap"
3534
val DIALECT = "spark.sql.dialect"
3635

3736
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
@@ -159,13 +158,6 @@ private[sql] class SQLConf extends Serializable {
159158
*/
160159
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean
161160

162-
/**
163-
* When set to true, Spark SQL will use off-heap memory allocation for managed memory operations.
164-
*
165-
* Defaults to false.
166-
*/
167-
private[spark] def unsafeUseOffHeap: Boolean = getConf(UNSAFE_USE_OFF_HEAP, "false").toBoolean
168-
169161
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
170162

171163
/**

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,8 +1013,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
10131013

10141014
def unsafeEnabled: Boolean = self.conf.unsafeEnabled
10151015

1016-
def unsafeUseOffHeap: Boolean = self.conf.unsafeUseOffHeap
1017-
10181016
def numPartitions: Int = self.conf.numShufflePartitions
10191017

10201018
def strategies: Seq[Strategy] =

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20+
import org.apache.spark.SparkEnv
2021
import org.apache.spark.annotation.DeveloperApi
2122
import org.apache.spark.rdd.RDD
2223
import org.apache.spark.sql.catalyst.trees._
@@ -43,16 +44,14 @@ case class AggregateEvaluation(
4344
* @param aggregateExpressions expressions that are computed for each group.
4445
* @param child the input data source.
4546
* @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
46-
* @param useOffHeap whether to use off-heap allocation (only takes effect if unsafeEnabled=true)
4747
*/
4848
@DeveloperApi
4949
case class GeneratedAggregate(
5050
partial: Boolean,
5151
groupingExpressions: Seq[Expression],
5252
aggregateExpressions: Seq[NamedExpression],
5353
child: SparkPlan,
54-
unsafeEnabled: Boolean,
55-
useOffHeap: Boolean)
54+
unsafeEnabled: Boolean)
5655
extends UnaryNode {
5756

5857
override def requiredChildDistribution: Seq[Distribution] =
@@ -291,7 +290,7 @@ case class GeneratedAggregate(
291290
newAggregationBuffer(EmptyRow),
292291
aggregationBufferSchema,
293292
groupKeySchema,
294-
if (useOffHeap) MemoryAllocator.UNSAFE else MemoryAllocator.HEAP,
293+
SparkEnv.get.unsafeMemoryManager,
295294
1024 * 16,
296295
false
297296
)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
141141
groupingExpressions,
142142
partialComputation,
143143
planLater(child),
144-
unsafeEnabled,
145-
unsafeUseOffHeap),
146-
unsafeEnabled,
147-
unsafeUseOffHeap) :: Nil
144+
unsafeEnabled),
145+
unsafeEnabled) :: Nil
148146

149147
// Cases where some aggregate can not be codegened
150148
case PartialAggregation(

0 commit comments

Comments
 (0)