 
 package org.apache.spark.shuffle.unsafe
 
-import org.apache.spark.ShuffleSuite
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
@@ -30,4 +38,66 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
     // shuffle records.
     conf.set("spark.shuffle.memoryFraction", "0.5")
   }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
 }