diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index be63c637a3a1..168d964c99d0 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -300,6 +300,13 @@ package object config {
     .booleanConf
     .createWithDefault(false)
 
+  private[spark] val SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD =
+    ConfigBuilder("spark.shuffle.highlyCompressedMapStatusThreshold")
+      .doc("Compress the size of shuffle blocks in HighlyCompressedMapStatus when the number of " +
+        "reduce partitions is above this threshold.")
+      .intConf
+      .createWithDefault(2000)
+
   private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
     ConfigBuilder("spark.shuffle.accurateBlockThreshold")
       .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 5e45b375ddd4..d9342904872b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -50,7 +50,10 @@ private[spark] sealed trait MapStatus {
 private[spark] object MapStatus {
 
   def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
-    if (uncompressedSizes.length > 2000) {
+    val threshold = Option(SparkEnv.get)
+      .map(_.conf.get(config.SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD))
+      .getOrElse(config.SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD.defaultValue.get)
+    if (uncompressedSizes.length > threshold) {
       HighlyCompressedMapStatus(loc, uncompressedSizes)
     } else {
       new CompressedMapStatus(loc, uncompressedSizes)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index e6120139f495..13487584b22a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -31,6 +31,11 @@ import org.apache.spark.storage.BlockManagerId
 
 class MapStatusSuite extends SparkFunSuite {
 
+  val env = mock(classOf[SparkEnv])
+  SparkEnv.set(env)
+  val conf = new SparkConf()
+  doReturn(conf).when(env).conf
+
   test("compressSize") {
     assert(MapStatus.compressSize(0L) === 0)
     assert(MapStatus.compressSize(1L) === 1)
@@ -71,17 +76,20 @@ class MapStatusSuite extends SparkFunSuite {
     }
   }
 
-  test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
-    val sizes = Array.fill[Long](2001)(150L)
+  test("large tasks (over spark.shuffle.highlyCompressedMapStatusThreshold) should use " +
+    classOf[HighlyCompressedMapStatus].getName) {
+    conf.set(config.SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD.key, "1000")
+    val sizes = Array.fill[Long](1001)(150L)
     val status = MapStatus(null, sizes)
     assert(status.isInstanceOf[HighlyCompressedMapStatus])
     assert(status.getSizeForBlock(10) === 150L)
     assert(status.getSizeForBlock(50) === 150L)
     assert(status.getSizeForBlock(99) === 150L)
-    assert(status.getSizeForBlock(2000) === 150L)
+    assert(status.getSizeForBlock(1000) === 150L)
   }
 
   test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") {
+    conf.set(config.SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD.key, "2000")
     val sizes = Array.tabulate[Long](3000) { i => i.toLong }
     val avg = sizes.sum / sizes.count(_ != 0)
     val loc = BlockManagerId("a", "b", 10)
@@ -135,10 +143,8 @@ class MapStatusSuite extends SparkFunSuite {
 
   test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " +
     "underestimated.") {
-    val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000")
-    val env = mock(classOf[SparkEnv])
-    doReturn(conf).when(env).conf
-    SparkEnv.set(env)
+    conf.set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000")
+      .set(config.SHUFFLE_HIGHLY_COMPRESSED_MAPSTATUS_THRESHOLD.key, "2000")
     // Value of element in sizes is equal to the corresponding index.
     val sizes = (0L to 2000L).toArray
     val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)