
Commit 02f3187

make gpu caching effective (without calling .cache)

1 parent d371450

File tree: 4 files changed (+33, -15 lines)


core/src/main/scala/org/apache/spark/ColumnPartitionData.scala (+5, -5)
@@ -85,7 +85,7 @@ class ColumnPartitionData[T](
   // Extracted to a function for use in deserialization
   private def initialize {
     pointers = schema.columns.map { col =>
-      SparkEnv.get.heapMemoryAllocator.allocatePinnedMemory(col.columnType.bytes * size)
+      SparkEnv.get.heapMemoryAllocator.allocateMemory(col.columnType.bytes * size)
     }

     refCounter = 1
@@ -96,7 +96,7 @@
     if (blobs == null) {
       blobs = new Array(1)
     }
-    val ptr = SparkEnv.get.heapMemoryAllocator.allocatePinnedMemory(blobSize)
+    val ptr = SparkEnv.get.heapMemoryAllocator.allocateMemory(blobSize)
     blobs(0) = ptr
     if (blobBuffers == null) {
       blobBuffers = new Array(1)
@@ -148,9 +148,9 @@
     assert(refCounter > 0)
     refCounter -= 1
     if (refCounter == 0) {
-      pointers.foreach(SparkEnv.get.heapMemoryAllocator.freePinnedMemory(_))
+      pointers.foreach(SparkEnv.get.heapMemoryAllocator.freeMemory(_))
       if (blobs != null) {
-        blobs.foreach(SparkEnv.get.heapMemoryAllocator.freePinnedMemory(_))
+        blobs.foreach(SparkEnv.get.heapMemoryAllocator.freeMemory(_))
       }
       freeGPUPointers()
     }
@@ -608,7 +608,7 @@
     while (i < blobBuffersSize) {
       val blobSize = in.readLong()
       var blobOffset: Long = 8
-      val ptr = SparkEnv.get.heapMemoryAllocator.allocatePinnedMemory(blobSize)
+      val ptr = SparkEnv.get.heapMemoryAllocator.allocateMemory(blobSize)
       blobs(i) = ptr
       val byteBuffer = ptr.getByteBuffer(0, blobSize).order(ByteOrder.LITTLE_ENDIAN)
       blobBuffers(i) = byteBuffer
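
Note: the four hunks above all make the same substitution, moving ColumnPartitionData's column and blob buffers from pinned (page-locked) host memory to ordinary pageable allocations. Pinned memory speeds up host-to-device copies but cannot be swapped and is limited system-wide, so holding it for as long as a GPU-cached partition lives would be costly; presumably that trade-off motivates the switch now that cacheGpu() keeps partitions alive across jobs. The HeapMemoryAllocator internals are not part of this diff, so the sketch below only illustrates the distinction, using JCuda (which the ptr.getByteBuffer call above suggests is in play); both helper names are hypothetical:

    import java.nio.{ByteBuffer, ByteOrder}
    import jcuda.Pointer
    import jcuda.runtime.JCuda

    // Pageable host buffer: a plain JVM allocation with no CUDA involvement;
    // the OS may page it out. allocateMemory is assumed to return something
    // of this kind in place of the pinned allocation used before this commit.
    def allocatePageable(bytes: Int): ByteBuffer =
      ByteBuffer.allocateDirect(bytes).order(ByteOrder.LITTLE_ENDIAN)

    // Pinned buffer: requested from the CUDA runtime and locked into physical
    // RAM until freed -- faster for host/device copies, but a scarce resource
    // that a long-lived cache should not hold indefinitely.
    def allocatePinned(bytes: Long): Pointer = {
      val ptr = new Pointer
      JCuda.cudaHostAlloc(ptr, bytes, 0 /* cudaHostAllocDefault */)
      ptr
    }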

core/src/main/scala/org/apache/spark/rdd/RDD.scala (+13, -2)
@@ -203,8 +203,16 @@ abstract class RDD[T: ClassTag](
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def cache(): this.type = persist()

-  def cacheGpu() : RDD[T] = { sc.env.gpuMemoryManager.cacheGPUSlaves(id); this }
-  def unCacheGpu() : RDD[T] = { sc.env.gpuMemoryManager.unCacheGPUSlaves(id); this }
+  def cacheGpu() : RDD[T] = {
+    sc.env.gpuMemoryManager.cacheGPUSlaves(id)
+    storageGpuLevel = StorageLevel.MEMORY_ONLY
+    this
+  }
+  def unCacheGpu() : RDD[T] = {
+    sc.env.gpuMemoryManager.unCacheGPUSlaves(id)
+    storageGpuLevel = StorageLevel.NONE
+    this
+  }

   /**
    * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
@@ -294,6 +302,8 @@
   final def partitionData(split: Partition, context: TaskContext): PartitionData[T] = {
     if (storageLevel != StorageLevel.NONE) {
       SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+    } else if (storageGpuLevel != StorageLevel.NONE) {
+      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageGpuLevel)
     } else {
       computeOrReadCheckpoint(split, context)
     }
@@ -1717,6 +1727,7 @@
   // =======================================================================

   private var storageLevel: StorageLevel = StorageLevel.NONE
+  private var storageGpuLevel: StorageLevel = StorageLevel.NONE

   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
   @transient private[spark] val creationSite = sc.getCallSite()
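
Note: this file is the heart of the commit. cacheGpu() now records a private storageGpuLevel alongside the existing storageLevel, and partitionData falls back to it, so a GPU-cached RDD flows through the same cacheManager.getOrCompute path that cache()/persist() uses; previously the GPU cache only took effect if the user also called .cache(). A hedged usage sketch (rowRDD is a placeholder; the API names come from this diff and the files below):

    // Before: rowRDD.convert(ColumnFormat).cache().cacheGpu()
    // After:  cacheGpu() alone is sufficient.
    val columnRDD = rowRDD.convert(ColumnFormat).cacheGpu()
    columnRDD.count()      // first action computes and caches the partitions
    columnRDD.count()      // later actions reuse them; no recomputation
    columnRDD.unCacheGpu() // storageGpuLevel back to NONE; next action recomputes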

core/src/test/scala/org/apache/spark/cuda/CUDAFunctionSuite.scala (+11, -6)
@@ -954,17 +954,22 @@ BB_RET:
       Some((size: Long) => 1),
       Some(dimensions)))

-    def generateData: Array[DataPoint] = {
+    def generateData(seed: Int, N: Int, D: Int, R: Double): DataPoint = {
+      val r = new Random(seed)
       def generatePoint(i: Int): DataPoint = {
         val y = if (i % 2 == 0) -1 else 1
-        val x = Array.fill(D){rand.nextGaussian + y * R}
+        val x = Array.fill(D){r.nextGaussian + y * R}
         DataPoint(x, y)
       }
-      Array.tabulate(N)(generatePoint)
+      generatePoint(seed)
     }

-    val pointsCached = sc.parallelize(generateData, numSlices).cache()
-    val pointsColumnCached = pointsCached.convert(ColumnFormat, false).cache().cacheGpu()
+    val skelton = sc.parallelize((1 to N), numSlices)
+    val pointsCached = skelton.map(i => generateData(i, N, D, R)).cache
+    pointsCached.count()
+
+    val pointsColumnCached = pointsCached.convert(ColumnFormat, false).cacheGpu()
+    pointsColumnCached.count()

     // Initialize w to a random value
     var wCPU = Array.fill(D){2 * rand.nextDouble - 1}
@@ -1031,7 +1036,7 @@
     assert(r2.sameElements(r1.map(mulby2)))

     // UncacheGPU should clear the GPU cache.
-    baseRDD.unCacheGpu().unCacheGpu()
+    baseRDD.unCacheGpu()
     r1 = baseRDD.mapExtFunc((x: Int) => 2 * x, mapFunction).collect()
     r2 = baseRDD.mapExtFunc((x: Int) => 2 * x, mapFunction).collect()
     assert(r2.sameElements(r1))
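
Note: the test changes track the new semantics in two ways. First, data generation moves off the driver: rather than building Array.tabulate(N)(generatePoint) locally and parallelizing the result, the test parallelizes the indices 1 to N and derives each DataPoint from a per-index seed, so the data is identical no matter which executor computes (or recomputes) a partition. A minimal sketch of that pattern, with the test's unused N parameter dropped and a guessed DataPoint shape:

    import scala.util.Random

    case class DataPoint(x: Array[Double], y: Double)

    // Element i seeds its own Random(i): any executor that recomputes
    // element i reproduces exactly the same point, which matters once
    // cached partitions can be evicted and rebuilt.
    def generateData(seed: Int, D: Int, R: Double): DataPoint = {
      val r = new Random(seed)
      val y = if (seed % 2 == 0) -1 else 1
      DataPoint(Array.fill(D)(r.nextGaussian + y * R), y)
    }

Second, a single unCacheGpu() replaces the doubled call, consistent with unCacheGpu() now resetting storageGpuLevel outright rather than relying on repeated calls to the GPU memory manager.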

examples/src/main/scala/org/apache/spark/examples/SparkGPULR.scala (+4, -2)
@@ -102,8 +102,10 @@ object SparkGPULR {
       Some(dimensions)))

     val skelton = sc.parallelize((1 to N), numSlices)
-    val points = skelton.map(i => generateData(i, N, D, R))
-    val pointsColumnCached = points.convert(ColumnFormat).cache().cacheGpu()
+    val points = skelton.map(i => generateData(i, N, D, R)).cache
+    points.count()
+
+    val pointsColumnCached = points.convert(ColumnFormat).cacheGpu()
     pointsColumnCached.count()

     // Initialize w to a random value
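
Note: the example mirrors the test suite. The row-format points RDD is cached and eagerly materialized with count() before conversion, and the extra .cache() on the column-format RDD is gone: with this commit, cacheGpu() alone lets every iteration of the logistic-regression loop reuse the converted partitions instead of regenerating and reconverting the input.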
