diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index e3bd5496cf5b..323a5d3c5283 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,8 +167,7 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 069e6d5f224d..4839d04522f1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,8 +248,7 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index ff85e11409e3..f713379326f6 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,9 +23,6 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) - * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics( - val shuffleId: Int, - val bytesByPartitionId: Array[Long], - val recordsByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 41575ce4e6e3..42b66285b21f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -522,7 +522,6 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) - val recordsByMapTask = new Array[Long](statuses.length) val 
parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) @@ -530,11 +529,10 @@ private[spark] class MapOutputTrackerMaster( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - statuses.zipWithIndex.foreach { case (s, index) => + for (s <- statuses) { for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } - recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -551,11 +549,8 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } - statuses.zipWithIndex.foreach { case (s, index) => - recordsByMapTask(index) = s.numberOfOutput - } } - new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) + new MapOutputStatistics(dep.shuffleId, totalSizes) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 7e1d75fe723d..0e7d652cd5e3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, - * for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -45,23 +44,18 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long - - /** - * The number of outputs for the map task. 
- */ - def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) + HighlyCompressedMapStatus(loc, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes, numOutput) + new CompressedMapStatus(loc, uncompressedSizes) } } @@ -104,34 +98,29 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte], - private[this] var numOutput: Long) + private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { - this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc - override def numberOfOutput: Long = numOutput - override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -154,20 +143,17 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte], - private[this] var numOutput: Long) + private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc - override def numberOfOutput: Long = numOutput - override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -182,7 +168,6 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -194,7 +179,6 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = 
BlockManagerId(in) - numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -210,10 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply( - loc: BlockManagerId, - uncompressedSizes: Array[Long], - numOutput: Long): HighlyCompressedMapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -254,6 +235,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap, numOutput) + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 91fc26762e53..274399b9cc1f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,8 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, - writeMetrics.recordsWritten) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index faa70f23b0ac..0d5c5ea7903e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,7 +233,6 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); - assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -253,7 +252,6 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); - assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index e79739692fe1..21f481d47724 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = 
MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), 10)) + Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L), 10)) + Array(10000L, 1000L))) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000), 10)) + Array(compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000), 10)) + Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) + Array(compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) + Array(compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) + BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 1)) + Array(2L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 1)) + Array(2L))) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L), 1)) + Array(3L))) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
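
The `decompressSize(compressSize(...))` round-trips in the tests above exist because a `CompressedMapStatus` records each per-reducer output size as a single byte on a logarithmic scale, so stored sizes are approximate and expected values must be run through the same codec. The sketch below illustrates the idea; `SizeCodec` is a hypothetical name, and the authoritative implementation lives in `org.apache.spark.scheduler.MapStatus`.

```scala
// Illustrative sketch of the 1-byte, log-scale size encoding used by MapStatus.
object SizeCodec {
  private val LogBase = 1.1

  def compress(size: Long): Byte = {
    if (size == 0) {
      0
    } else if (size <= 1L) {
      1
    } else {
      // ceil(log_1.1(size)), capped at 255 so the bucket index fits in one unsigned byte.
      math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte
    }
  }

  def decompress(compressed: Byte): Long = {
    if (compressed == 0) 0L else math.pow(LogBase, compressed & 0xFF).toLong
  }
}

// The encoding is lossy: SizeCodec.decompress(SizeCodec.compress(1000L)) is roughly 1050,
// not exactly 1000, which is why the tests compare against the round-tripped value.
```
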
@@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000), 1)) + Array(size0, size1000, size0, size10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0), 1)) + Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 456f97b535ef..b917469e4874 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -391,7 +391,6 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) - assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 365eab0668ab..d6c9ae6ab519 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -445,17 +445,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -2857,7 +2857,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - 
MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 555e48bd28aa..354e6386fa60 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 1) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 1) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. 
val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 36912441c03b..ac25bcef5434 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,8 +345,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize( - HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cd28c733f361..a1e0d529c000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -208,18 +207,6 @@ case object SinglePartition extends Partitioning { } } -/** - * Represents a partitioning where rows are only serialized/deserialized locally. The number - * of partitions are not changed and also the distribution of rows. This is mainly used to - * obtain some statistics of map tasks such as number of outputs. - */ -case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { - val numPartitions = childRDD.getNumPartitions - - // We will perform this partitioning no matter what the data distribution is. - override def satisfies0(required: Distribution): Boolean = false -} - /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. 
All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4928560eacb1..cccb528d6830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -255,13 +255,6 @@ object SQLConf { .intConf .createWithDefault(4) - val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") - .internal() - .doc("During global limit, try to evenly distribute limited rows across data " + - "partitions. If disabled, scanning data partitions sequentially until reaching limit number.") - .booleanConf - .createWithDefault(true) - val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -1765,8 +1758,6 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) - def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) - def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 89442a70283f..5709f8016bc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,35 +66,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { - private def decideTopRankNode(limit: Int, child: LogicalPlan): Seq[SparkPlan] = { - if (limit < conf.topKSortFallbackThreshold) { - child match { - case Sort(order, true, child) => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Project(projectList, Sort(order, true, child)) => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil - } - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(child)), - orderedLimit = true) :: Nil - } - } - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => - decideTopRankNode(limit, s) - case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => - decideTopRankNode(limit, p) + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => - decideTopRankNode(limit, s) - case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => - decideTopRankNode(limit, p) + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + 
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 9576605b1a21..aba94885f941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,11 +231,6 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } - case l: LocalPartitioning => - new Partitioner { - override def numPartitions: Int = l.numPartitions - override def getPartition(key: Any): Int = key.asInstanceOf[Int] - } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -252,9 +247,6 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity - case _: LocalPartitioning => - val partitionId = TaskContext.get().partitionId() - _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1a09632f93ca..bd9bfab0b1c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,16 +47,14 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Helper trait which defines methods that are shared by both + * [[LocalLimitExec]] and [[GlobalLimitExec]]. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { +trait BaseLimitExec extends UnaryExecNode with CodegenSupport { + val limit: Int override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -95,97 +93,27 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi } } + /** - * Take the `limit` elements of the child output. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -case class GlobalLimitExec(limit: Int, child: SparkPlan, - orderedLimit: Boolean = false) extends UnaryExecNode { +case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { - override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning +} - override def outputOrdering: Seq[SortOrder] = child.outputOrdering +/** + * Take the first `limit` elements of the child's single output partition. 
+ */ +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil - protected override def doExecute(): RDD[InternalRow] = { - val childRDD = child.execute() - val partitioner = LocalPartitioning(childRDD) - val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( - childRDD, child.output, partitioner, serializer) - val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { - // submitMapStage does not accept RDD with 0 partition. - // So, we will not submit this dependency. - val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) - submittedStageFuture.get().recordsByPartitionId.toSeq - } else { - Nil - } + override def outputPartitioning: Partitioning = child.outputPartitioning - // This is an optimization to evenly distribute limited rows across all partitions. - // When enabled, Spark goes to take rows at each partition repeatedly until reaching - // limit number. When disabled, Spark takes all rows at first partition, then rows - // at second partition ..., until reaching limit number. - // The optimization is disabled when it is needed to keep the original order of rows - // before global sort, e.g., select * from table order by col limit 10. - val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit - - val shuffled = new ShuffledRowRDD(shuffleDependency) - - val sumOfOutput = numberOfOutput.sum - if (sumOfOutput <= limit) { - shuffled - } else if (!flatGlobalLimit) { - var numRowTaken = 0 - val takeAmounts = numberOfOutput.map { num => - if (numRowTaken + num < limit) { - numRowTaken += num.toInt - num.toInt - } else { - val toTake = limit - numRowTaken - numRowTaken += toTake - toTake - } - } - val broadMap = sparkContext.broadcast(takeAmounts) - shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - iter.take(broadMap.value(index).toInt) - } - } else { - // We try to evenly require the asked limit number of rows across all child rdd's partitions. - var rowsNeedToTake: Long = limit - val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) - val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) - - while (rowsNeedToTake > 0) { - val nonEmptyParts = remainingRowsByPartition.count(_ > 0) - // If the rows needed to take are less the number of non-empty partitions, take one row from - // each non-empty partitions until we reach `limit` rows. - // Otherwise, evenly divide the needed rows to each non-empty partitions. - val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) - remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => - // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during - // the traversal, so we need to add this check. 
- if (rowsNeedToTake > 0 && num > 0) { - if (num >= takePerPart) { - rowsNeedToTake -= takePerPart - takeAmountByPartition(index) += takePerPart - remainingRowsByPartition(index) -= takePerPart - } else { - rowsNeedToTake -= num - takeAmountByPartition(index) += num - remainingRowsByPartition(index) -= num - } - } - } - } - val broadMap = sparkContext.broadcast(takeAmountByPartition) - shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - iter.take(broadMap.value(index).toInt) - } - } - } + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index e33cd819f281..b4c73cf33e53 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,5 +1,3 @@ --- Disable global limit parallel -set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a862e0985b20..4d901a57637a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,9 +1,6 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. --- Disable global limit optimization -set spark.sql.limit.flatGlobalLimit=false; - create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 187f3bd6858f..02fe1de84f75 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,134 +1,126 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 14 -- !query 0 -set spark.sql.limit.flatGlobalLimit=false --- !query 0 schema -struct --- !query 0 output -spark.sql.limit.flatGlobalLimit false - - --- !query 1 SELECT * FROM testdata LIMIT 2 --- !query 1 schema +-- !query 0 schema struct --- !query 1 output +-- !query 0 output 1 1 2 2 --- !query 2 +-- !query 1 SELECT * FROM arraydata LIMIT 2 --- !query 2 schema +-- !query 1 schema struct,nestedarraycol:array>> --- !query 2 output +-- !query 1 output [1,2,3] [[1,2,3]] [2,3,4] [[2,3,4]] --- !query 3 +-- !query 2 SELECT * FROM mapdata LIMIT 2 --- !query 3 schema +-- !query 2 schema struct> --- !query 3 output +-- !query 2 output {1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} {1:"a2",2:"b2",3:"c2",4:"d2"} --- !query 4 +-- !query 3 SELECT * FROM testdata LIMIT 2 + 1 --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output 1 1 2 2 3 3 --- !query 5 +-- !query 4 SELECT * FROM testdata LIMIT CAST(1 AS int) --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output 1 1 --- !query 6 +-- !query 5 SELECT * FROM testdata LIMIT -1 --- !query 6 schema +-- !query 5 schema struct<> --- !query 6 output +-- !query 5 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or 
greater than 0, but got -1; --- !query 7 +-- !query 6 SELECT * FROM testData TABLESAMPLE (-1 ROWS) --- !query 7 schema +-- !query 6 schema struct<> --- !query 7 output +-- !query 6 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 8 +-- !query 7 SELECT * FROM testdata LIMIT CAST(1 AS INT) --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output 1 1 --- !query 9 +-- !query 8 SELECT * FROM testdata LIMIT CAST(NULL AS INT) --- !query 9 schema +-- !query 8 schema struct<> --- !query 9 output +-- !query 8 output org.apache.spark.sql.AnalysisException The evaluated limit expression must not be null, but got CAST(NULL AS INT); --- !query 10 +-- !query 9 SELECT * FROM testdata LIMIT key > 3 --- !query 10 schema +-- !query 9 schema struct<> --- !query 10 output +-- !query 9 output org.apache.spark.sql.AnalysisException The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); --- !query 11 +-- !query 10 SELECT * FROM testdata LIMIT true --- !query 11 schema +-- !query 10 schema struct<> --- !query 11 output +-- !query 10 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got boolean; --- !query 12 +-- !query 11 SELECT * FROM testdata LIMIT 'a' --- !query 12 schema +-- !query 11 schema struct<> --- !query 12 output +-- !query 11 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got string; --- !query 13 +-- !query 12 SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 --- !query 13 schema +-- !query 12 schema struct --- !query 13 output +-- !query 12 output 4 --- !query 14 +-- !query 13 SELECT * FROM testdata WHERE key < 3 LIMIT ALL --- !query 14 schema +-- !query 13 schema struct --- !query 14 output +-- !query 13 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 9eb5b3383e73..71ca1f864947 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,16 +1,8 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 8 -- !query 0 -set spark.sql.limit.flatGlobalLimit=false --- !query 0 schema -struct --- !query 0 output -spark.sql.limit.flatGlobalLimit false - - --- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -25,13 +17,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 1 schema +-- !query 0 schema struct<> --- !query 1 output +-- !query 0 output --- !query 2 +-- !query 1 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -47,13 +39,13 @@ create temporary view t2 as 
select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 2 schema +-- !query 1 schema struct<> --- !query 2 output +-- !query 1 output --- !query 3 +-- !query 2 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -68,27 +60,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 3 schema +-- !query 2 schema struct<> --- !query 3 output +-- !query 2 output --- !query 4 +-- !query 3 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 4 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -96,16 +88,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 6 +-- !query 5 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -116,29 +108,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 6 schema +-- !query 5 schema struct --- !query 6 output +-- !query 5 output 1 NULL --- !query 7 +-- !query 6 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 7 schema +-- !query 6 schema struct --- !query 7 output +-- !query 6 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 8 +-- !query 7 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -149,7 +141,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output 1 6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ed110f751645..89bae5f7186b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,13 +557,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value").repartition(1) - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - 
limit2Df.select($"id")) - } + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value").repartition(1) + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) } test("SPARK-17237 remove backticks in a pivot result schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f001b138f4b8..279b7b8d49f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ @@ -2552,26 +2552,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val baseDf = spark.range(1000).toDF.repartition(3).sort("id") - - withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { - val expected = baseDf.limit(99) - val takeOrderedNode1 = expected.queryExecution.executedPlan - .find(_.isInstanceOf[TakeOrderedAndProjectExec]) - assert(takeOrderedNode1.isDefined) - - val result = baseDf.limit(100) - val takeOrderedNode2 = result.queryExecution.executedPlan - .find(_.isInstanceOf[TakeOrderedAndProjectExec]) - assert(takeOrderedNode2.isEmpty) - - checkAnswer(expected, result.collect().take(99)) - } - } - } - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { val df1 = spark.createDataFrame(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 01dc28d70184..8fcebb35a054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -524,15 +524,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("limit for skew dataframe") { - // Create a skew dataframe. - val df = testData.repartition(100).union(testData).limit(50) - // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`, - // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` - // work on skew partitions. - assert(df.rdd.count() == 50L) - } - test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1944,7 +1935,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. 
- val df = sql("SELECT a, b from testData2 order by a, b limit 1") + val df = sql("SELECT a, b from testData2 limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index c627c51655c8..6ad025f37e44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -55,7 +55,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) + new MapOutputStatistics(index, bytesByPartitionId) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -119,8 +119,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), - new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) + new MapOutputStatistics(0, bytesByPartitionId1), + new MapOutputStatistics(1, bytesByPartitionId2)) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala deleted file mode 100644 index a7840a5fcfae..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext - - -class LimitSuite extends SparkPlanTest with SharedSQLContext { - - private var rand: Random = _ - private var seed: Long = 0 - - protected override def beforeAll(): Unit = { - super.beforeAll() - seed = System.currentTimeMillis() - rand = new Random(seed) - } - - test("Produce ordered global limit if more than topKSortFallbackThreshold") { - withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { - val df = LimitTest.generateRandomInputData(spark, rand).sort("a") - - val globalLimit = df.limit(99).queryExecution.executedPlan.collect { - case g: GlobalLimitExec => g - } - assert(globalLimit.size == 0) - - val topKSort = df.limit(99).queryExecution.executedPlan.collect { - case t: TakeOrderedAndProjectExec => t - } - assert(topKSort.size == 1) - - val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect { - case g: GlobalLimitExec => g - } - assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true) - } - } - - test("Ordered global limit") { - val baseDf = LimitTest.generateRandomInputData(spark, rand) - .select("a").repartition(3).sort("a") - - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, - orderedLimit = true) - val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext) - .map(_.getInt(0)) - - val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false) - val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext) - .map(_.getInt(0)) - - // Global limit without order takes values at each partition sequentially. - // After global sort, the values in second partition must be larger than the values - // in first partition. 
- assert(orderedGlobalLimitResult(0) == globalLimitResult(0)) - assert(orderedGlobalLimitResult(1) < globalLimitResult(1)) - assert(orderedGlobalLimitResult(2) < globalLimitResult(2)) - } - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b10da6c70be1..e4e224df7607 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 9322204063af..7e317a4d8026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -38,6 +37,14 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { rand = new Random(seed) } + private def generateRandomInputData(): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } + /** * Adds a no-op filter to the child plan in order to prevent executeCollect() from being * called directly on the child plan. 
@@ -48,62 +55,32 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) - } - } - } - - test("TakeOrderedAndProject.doExecute equals to ordered global limit") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input)), orderedLimit = true), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) } } } - -object LimitTest { - def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = { - val schema = new StructType() - .add("a", IntegerType, nullable = false) - .add("b", IntegerType, nullable = false) - val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema) - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b9b2b7dbf38e..cebaad5b4ad9 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,7 +40,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled - private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = 
TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -60,8 +59,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) - // Ensure that limit operation returns rows in the same order as Hive - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -76,7 +73,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 16541295eb45..cc592cf6ca62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,29 +22,21 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} -import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit - override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } - override def afterAll() { - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) - super.afterAll() - } // Column pruning tests
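
The biggest functional removal above is `GlobalLimitExec`'s "flat global limit" path, which used the per-partition record counts obtained via `submitMapStage` to spread the requested number of rows evenly across the child's partitions instead of draining partitions in order. For readers following what was deleted, here is that distribution loop extracted as a standalone, lightly condensed sketch (the function name and example values are illustrative only):

```scala
// Decide how many rows to take from each partition so that `limit` rows are spread
// roughly evenly, given the number of records available in each partition.
// Precondition (as in the removed code): recordsPerPartition.sum > limit, so every
// round finds at least one non-empty partition and the loop terminates.
def distributeLimit(limit: Long, recordsPerPartition: Array[Long]): Array[Long] = {
  var rowsNeedToTake = limit
  val takeAmountByPartition = Array.fill[Long](recordsPerPartition.length)(0L)
  val remainingRowsByPartition = recordsPerPartition.clone()

  while (rowsNeedToTake > 0) {
    val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
    // Take at least one row per non-empty partition each round; otherwise split the
    // outstanding rows evenly across the non-empty partitions.
    val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
    remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
      // rowsNeedToTake can reach zero mid-traversal when it is smaller than the
      // number of non-empty partitions, hence the extra check.
      if (rowsNeedToTake > 0 && num > 0) {
        val toTake = math.min(num, takePerPart)
        rowsNeedToTake -= toTake
        takeAmountByPartition(index) += toTake
        remainingRowsByPartition(index) -= toTake
      }
    }
  }
  takeAmountByPartition
}

// Example: distributeLimit(6, Array(5L, 1L, 4L)) returns Array(3, 1, 2), spreading the
// limit across partitions, whereas draining partitions in order would yield (5, 1, 0).
```
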
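With `decideTopRankNode` gone, the planning rule above returns to its simpler form: a `LIMIT` on top of a global `ORDER BY` is planned as a single `TakeOrderedAndProjectExec` (top-K sort) only when the limit is strictly below the top-K fallback threshold; otherwise the generic limit operators (`CollectLimitExec`, or `LocalLimitExec` plus `GlobalLimitExec`) apply, and `GlobalLimitExec` no longer carries an `orderedLimit` flag. A quick way to see which operator gets planned, mirroring the removed `DataFrameSuite` test (assumes an active `SparkSession` named `spark`; the threshold value is chosen only for illustration):

```scala
import org.apache.spark.sql.execution.TakeOrderedAndProjectExec

// Illustrative threshold; the config key is spark.sql.execution.topKSortFallbackThreshold.
spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", "100")

val sorted = spark.range(1000).toDF("id").sort("id")

// Below the threshold: planned as a single top-K sort.
val topK = sorted.limit(99).queryExecution.executedPlan
  .find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(topK.isDefined)

// At or above the threshold: falls back to the generic limit operators instead.
val fallback = sorted.limit(100).queryExecution.executedPlan
  .find(_.isInstanceOf[TakeOrderedAndProjectExec])
assert(fallback.isEmpty)
```
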
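Finally, a note on the `getStatistics` hunk above: only the per-map record counts are dropped; the byte-size aggregation and its parallelism heuristic are unchanged. The thread count is `min(available cores, numMapStatuses * numReducePartitions / parallelAggThreshold + 1)`, and when that comes out as one the sizes are summed in the plain nested loop shown in the hunk. A worked example with assumed values (the threshold corresponds to `spark.shuffle.mapOutput.parallelAggregationThreshold`; the 10^7 value below is an assumption):

```scala
// Assumed sizes for a hypothetical shuffle; the threshold mirrors
// spark.shuffle.mapOutput.parallelAggregationThreshold (10^7 assumed here).
val numMapStatuses = 5000L
val numReducePartitions = 3000L
val parallelAggThreshold = 10000000L

val parallelism = math.min(
  Runtime.getRuntime.availableProcessors(),
  numMapStatuses * numReducePartitions / parallelAggThreshold + 1).toInt

// 5000 * 3000 / 10^7 + 1 = 2, so (up to) two threads aggregate the per-reducer sizes;
// had this come out as 1, the serial nested-loop path would be used instead.
println(parallelism)
```
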