
Commit 7610f2f

Add tests for proper cleanup of shuffle data.

1 parent d494ffe commit 7610f2f

3 files changed, +92 -8 lines changed


core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     true
   }

-  override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     indexShuffleBlockResolver
   }

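A side note on this one-line change: the body just returns a stored resolver, so switching `def` to `val` does not change behavior here; it evaluates the member once at construction and records the stable-identity guarantee in the declaration, matching the new `val` override added to UnsafeShuffleManager below. A minimal standalone sketch of the general `def`-vs-`val` override semantics (illustrative names, not Spark's):

trait HasResolver {
  def resolver: AnyRef
}

class EagerResolver extends HasResolver {
  // A `val` may override a `def`: the body runs once, at construction,
  // and every caller sees the same instance.
  override val resolver: AnyRef = new Object
}

class PerCallResolver extends HasResolver {
  // A `def` override re-runs its body on every access, so callers may
  // observe a different instance each time.
  override def resolver: AnyRef = new Object
}

object OverrideSemanticsDemo extends App {
  val eager = new EagerResolver
  assert(eager.resolver eq eager.resolver)        // same object every time
  val perCall = new PerCallResolver
  assert(!(perCall.resolver eq perCall.resolver)) // fresh object per call
}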
core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala

Lines changed: 20 additions & 6 deletions

@@ -17,6 +17,9 @@

 package org.apache.spark.shuffle.unsafe

+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
 import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
@@ -25,7 +28,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
 /**
  * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
  */
-private class UnsafeShuffleHandle[K, V](
+private[spark] class UnsafeShuffleHandle[K, V](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, V])
@@ -121,8 +124,10 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
      "manager; its optimized shuffles will continue to spill to disk when necessary.")
   }

-
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()

   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -158,8 +163,8 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
      context: TaskContext): ShuffleWriter[K, V] = {
    handle match {
      case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
        val env = SparkEnv.get
-        // TODO: do we need to do anything to register the shuffle here?
        new UnsafeShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
@@ -170,17 +175,26 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
          context,
          env.conf)
      case other =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
        sortShuffleManager.getWriter(handle, mapId, context)
    }
  }

  /** Remove a shuffle's metadata from the ShuffleManager. */
  override def unregisterShuffle(shuffleId: Int): Boolean = {
-    // TODO: need to do something here for our unsafe path
-    sortShuffleManager.unregisterShuffle(shuffleId)
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
  }

-  override def shuffleBlockResolver: ShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
    sortShuffleManager.shuffleBlockResolver
  }

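The substance of this file's change is the bookkeeping that lets `unregisterShuffle` clean up both code paths: a concurrent set of shuffle IDs that fell back to sort-based shuffle (whose cleanup is delegated), and a shuffleId-to-numMaps map for shuffles on the new path, whose per-map .data/.index files must be removed explicitly. A condensed, self-contained sketch of that pattern; the resolver call is stubbed with a println, and names are illustrative rather than Spark's API:

import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

object CleanupBookkeepingSketch {
  // Concurrent set of shuffle IDs that used the fallback path.
  private val fellBack =
    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
  // shuffleId -> number of map tasks, for shuffles on the new path.
  private val numMapsByShuffle = new ConcurrentHashMap[Int, Int]()

  // Stub standing in for IndexShuffleBlockResolver.removeDataByMap.
  private def removeDataByMap(shuffleId: Int, mapId: Int): Unit =
    println(s"deleting shuffle_${shuffleId}_${mapId}_0.{data,index}")

  def registerFallback(shuffleId: Int): Unit =
    fellBack.add(shuffleId)

  def registerNewPath(shuffleId: Int, numMaps: Int): Unit =
    numMapsByShuffle.putIfAbsent(shuffleId, numMaps)

  def unregisterShuffle(shuffleId: Int): Boolean = {
    if (fellBack.remove(shuffleId)) {
      // Fallback shuffles delegate cleanup to the sort-based manager.
      true
    } else {
      // New-path shuffles wrote one .data/.index pair per map task;
      // remove each of them explicitly.
      Option(numMapsByShuffle.remove(shuffleId)).foreach { numMaps =>
        (0 until numMaps).foreach(removeDataByMap(shuffleId, _))
      }
      true
    }
  }

  def main(args: Array[String]): Unit = {
    registerNewPath(shuffleId = 0, numMaps = 2)
    registerFallback(shuffleId = 1)
    unregisterShuffle(0) // prints two deletions
    unregisterShuffle(1) // delegated; nothing deleted here
  }
}

Using `putIfAbsent` keeps concurrent map tasks of the same shuffle from racing on registration, and removing the entry in both branches means each shuffle is cleaned up at most once.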
core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala

Lines changed: 71 additions & 1 deletion

@@ -17,9 +17,17 @@

 package org.apache.spark.shuffle.unsafe

-import org.apache.spark.ShuffleSuite
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll

+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {

   // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
@@ -30,4 +38,66 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
    // shuffle records.
    conf.set("spark.shuffle.memoryFraction", "0.5")
  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
 }
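Both tests rest on the same trick: point spark.local.dir at a fresh temporary directory, snapshot the recursive file listing before and after the shuffle runs, and treat the set difference as exactly the files the shuffle created. A minimal sketch of that technique without Spark (plain JDK I/O; the file name is a stand-in):

import java.io.File
import java.nio.file.Files

object FileSetDiffSketch {
  // Recursively collect every regular file under a directory.
  private def allFiles(dir: File): Set[File] =
    Option(dir.listFiles()).getOrElse(Array.empty[File]).flatMap { f =>
      if (f.isDirectory) allFiles(f) else Set(f)
    }.toSet

  def main(args: Array[String]): Unit = {
    val tmpDir = Files.createTempDirectory("scratch").toFile
    try {
      val before = allFiles(tmpDir)
      // Stand-in for the work under test, configured to write only under tmpDir.
      Files.createFile(new File(tmpDir, "shuffle_0_0_0.data").toPath)
      val created = allFiles(tmpDir) -- before
      assert(created.map(_.getName) == Set("shuffle_0_0_0.data"))
    } finally {
      allFiles(tmpDir).foreach(_.delete())
      tmpDir.delete()
    }
  }
}

In the suite itself, FileUtils.listFiles with TrueFileFilter.INSTANCE plays the role of allFiles, and Utils.createTempDir()/Utils.deleteRecursively handle the directory lifecycle.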
