diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c5283..e3bd5496cf5b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +167,8 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f1..069e6d5f224d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,7 +248,8 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cb..ff85e11409e3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val recordsByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541..41575ce4e6e3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -522,16 +522,19 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val recordsByMapTask = new Array[Long](statuses.length) + val 
parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - for (s <- statuses) { + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -548,8 +551,11 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } + statuses.zipWithIndex.foreach { case (s, index) => + recordsByMapTask(index) = s.numberOfOutput + } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694dd189a..7e1d75fe723d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -44,18 +45,23 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + /** + * The number of outputs for the map task. 
+ */ + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -98,29 +104,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -143,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -168,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -179,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = 
BlockManagerId(in) + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -194,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -235,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, numOutput) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f..91fc26762e53 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e..faa70f23b0ac 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,6 +233,7 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -252,6 +253,7 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d47724..e79739692fe1 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = 
MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
@@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 1)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index ced5a06516f7..d11eaf8c2749 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index dad339e2cdb9..b927d6a16ecb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -423,17 +423,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -2579,7 +2579,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - 
MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 354e6386fa60..555e48bd28aa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. 
val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index fc78655bf52e..240f8cf800fe 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e835d9c..cd28c733f361 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -206,6 +208,18 @@ case object SinglePartition extends Partitioning { } } +/** + * Represents a partitioning where rows are only serialized/deserialized locally. Neither the + * number of partitions nor the distribution of rows is changed. This is mainly used to collect + * statistics from the map tasks, such as the number of records each of them outputs. + */ +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions + + // We will perform this partitioning no matter what the data distribution is. + override def satisfies0(required: Distribution): Boolean = false +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. 
All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 67c3abb80c2c..966ad687ad2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -214,6 +214,13 @@ object SQLConf { .intConf .createWithDefault(4) + val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") + .internal() + .doc("During a global limit, try to evenly distribute the limited rows across the data " + + "partitions. If disabled, partitions are scanned sequentially until the limit is reached.") + .booleanConf + .createWithDefault(true) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -1673,6 +1680,8 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) + def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b89203719541..50f10c31427d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,6 +231,11 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case l: LocalPartitioning => + new Partitioner { + override def numPartitions: Int = l.numPartitions + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -247,6 +252,9 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case _: LocalPartitioning => + val partitionId = TaskContext.get().partitionId() + _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda891373..392ca13724bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. 
*/ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the first `limit` elements of the child's output. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(childRDD) + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept an RDD with 0 partitions, + // so we will not submit this dependency. + val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) + submittedStageFuture.get().recordsByPartitionId.toSeq + } else { + Nil + } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + // During a global limit, try to evenly distribute the limited rows across the data + // partitions. If disabled, partitions are scanned sequentially until the limit is reached. + // Besides, if the child output has a certain ordering, we cannot evenly pick up rows from + // each partition. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil + + val shuffled = new ShuffledRowRDD(shuffleDependency) + + val sumOfOutput = numberOfOutput.sum + if (sumOfOutput <= limit) { + shuffled + } else if (!flatGlobalLimit) { + var numRowTaken = 0 + val takeAmounts = numberOfOutput.map { num => + if (numRowTaken + num < limit) { + numRowTaken += num.toInt + num.toInt + } else { + val toTake = limit - numRowTaken + numRowTaken += toTake + toTake + } + } + val broadMap = sparkContext.broadcast(takeAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } else { + // We try to evenly spread the requested `limit` number of rows across all of the child RDD's partitions. 
+ var rowsNeedToTake: Long = limit + val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) + val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) + + while (rowsNeedToTake > 0) { + val nonEmptyParts = remainingRowsByPartition.count(_ > 0) + // If fewer rows are needed than there are non-empty partitions, take one row from + // each non-empty partition until we reach `limit` rows. + // Otherwise, evenly divide the needed rows among the non-empty partitions. + val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) + remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => + // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during + // the traversal, so we need to add this check. + if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } + } + } + } + val broadMap = sparkContext.broadcast(takeAmountByPartition) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } + } } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index b4c73cf33e53..e33cd819f281 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable global limit optimization set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee082ba3b..a862e0985b20 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. 
+-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 02fe1de84f75..187f3bd6858f 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,63 +1,62 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 15 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct +struct -- !query 0 output -1 1 -2 2 +spark.sql.limit.flatGlobalLimit false -- !query 1 -SELECT * FROM arraydata LIMIT 2 +SELECT * FROM testdata LIMIT 2 -- !query 1 schema -struct,nestedarraycol:array>> +struct -- !query 1 output -[1,2,3] [[1,2,3]] -[2,3,4] [[2,3,4]] +1 1 +2 2 -- !query 2 -SELECT * FROM mapdata LIMIT 2 +SELECT * FROM arraydata LIMIT 2 -- !query 2 schema -struct> +struct,nestedarraycol:array>> -- !query 2 output -{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} -{1:"a2",2:"b2",3:"c2",4:"d2"} +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] -- !query 3 -SELECT * FROM testdata LIMIT 2 + 1 +SELECT * FROM mapdata LIMIT 2 -- !query 3 schema -struct +struct> -- !query 3 output -1 1 -2 2 -3 3 +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} -- !query 4 -SELECT * FROM testdata LIMIT CAST(1 AS int) +SELECT * FROM testdata LIMIT 2 + 1 -- !query 4 schema struct -- !query 4 output 1 1 +2 2 +3 3 -- !query 5 -SELECT * FROM testdata LIMIT -1 +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 5 schema -struct<> +struct -- !query 5 output -org.apache.spark.sql.AnalysisException -The limit expression must be equal to or greater than 0, but got -1; +1 1 -- !query 6 -SELECT * FROM testData TABLESAMPLE (-1 ROWS) +SELECT * FROM testdata LIMIT -1 -- !query 6 schema struct<> -- !query 6 output @@ -66,61 +65,70 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -SELECT * FROM testdata LIMIT CAST(1 AS INT) +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 7 schema -struct +struct<> -- !query 7 output -1 1 +org.apache.spark.sql.AnalysisException +The limit expression must be equal to or greater than 0, but got -1; -- !query 8 -SELECT * FROM testdata LIMIT CAST(NULL AS INT) +SELECT * FROM testdata LIMIT CAST(1 AS INT) -- !query 8 schema -struct<> +struct -- !query 8 output -org.apache.spark.sql.AnalysisException -The evaluated limit expression must not be null, but got CAST(NULL AS INT); +1 1 -- !query 9 -SELECT * FROM testdata LIMIT key > 3 +SELECT * FROM testdata LIMIT CAST(NULL AS INT) -- !query 9 schema struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The evaluated limit expression must not be null, but got CAST(NULL AS INT); -- !query 10 -SELECT * FROM testdata LIMIT true +SELECT * FROM testdata LIMIT key > 3 -- !query 10 schema struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got 
boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 11 -SELECT * FROM testdata LIMIT 'a' +SELECT * FROM testdata LIMIT true -- !query 11 schema struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 12 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 12 schema -struct +struct<> -- !query 12 output -4 +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; -- !query 13 -SELECT * FROM testdata WHERE key < 3 LIMIT ALL +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 13 schema -struct +struct -- !query 13 output +4 + + +-- !query 14 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 14 schema +struct +-- !query 14 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f864947..9eb5b3383e73 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') 
as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,16 +96,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -108,29 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output 1 6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db..85b3ca11383f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 86083d1701c2..7546930ab1d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -523,6 +523,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } + test("limit for skew dataframe") { + // Create a skew dataframe. 
+ val df = testData.repartition(100).union(testData).limit(50) + // Because calling `rdd` on a DataFrame adds a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test whether `GlobalLimit` + // works on skewed partitions. + assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1934,7 +1943,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b736d43bfc6b..41de731d41f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d254345e8fa5..0cbe771b54ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cebaad5b4ad9..b9b2b7dbf38e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation returns rows in the same order as Hive + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index cc592cf6ca62..16541295eb45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() + } // Column pruning tests
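
The record count added by this patch travels end to end: each shuffle writer now passes `writeMetrics.recordsWritten` into `MapStatus.apply`, both `MapStatus` implementations serialize it, and `MapOutputTrackerMaster.getStatistics` surfaces it as `MapOutputStatistics.recordsByPartitionId`. Below is a minimal sketch of the `MapStatus` side, assuming it is compiled into the `org.apache.spark.scheduler` package of a build that contains this patch (the trait and its implementations are `private[spark]`); the object name and the concrete sizes are made up for illustration, and 2000 is the default of `spark.shuffle.minNumPartitionsToHighlyCompress`.

```scala
package org.apache.spark.scheduler

import org.apache.spark.storage.BlockManagerId

object NumberOfOutputSketch {
  def main(args: Array[String]): Unit = {
    val loc = BlockManagerId("exec-0", "host-0", 7337)

    // Few reduce partitions: the factory picks CompressedMapStatus; the record count rides along.
    val small = MapStatus(loc, Array(1000L, 0L, 10000L), numOutput = 42L)
    assert(small.isInstanceOf[CompressedMapStatus])
    assert(small.numberOfOutput == 42L)

    // Above the highly-compressed threshold (2000 partitions by default) the factory picks
    // HighlyCompressedMapStatus; numberOfOutput is preserved there as well.
    val large = MapStatus(loc, Array.fill[Long](3000)(150L), numOutput = 42L)
    assert(large.isInstanceOf[HighlyCompressedMapStatus])
    assert(large.numberOfOutput == 42L)
  }
}
```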
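The `LocalPartitioning` and `ShuffleExchangeExec` pieces rely on a simple trick: every row is keyed by the id of the map partition that produced it, and the partitioner maps that key back to itself, so the exchange never moves rows; it is submitted via `submitMapStage` only to obtain per-partition record counts. A rough standalone analogue using the public RDD API rather than the SQL-internal `ShuffledRowRDD` (the local master, app name, and numbers are illustrative):

```scala
import org.apache.spark.{Partitioner, SparkConf, SparkContext}

object IdentityShuffleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("identity-shuffle"))
    val data = sc.parallelize(1 to 100, 8)

    // Key each element by the partition it currently lives in, then "shuffle" with a
    // partitioner that sends key i back to partition i -- the layout stays unchanged.
    val keyed = data.mapPartitionsWithIndex { (pid, iter) => iter.map(x => (pid, x)) }
    val reshuffled = keyed.partitionBy(new Partitioner {
      override def numPartitions: Int = data.getNumPartitions
      override def getPartition(key: Any): Int = key.asInstanceOf[Int]
    })

    val before = data.mapPartitionsWithIndex { (pid, it) => Iterator(pid -> it.size) }.collect().toMap
    val after = reshuffled.mapPartitionsWithIndex { (pid, it) => Iterator(pid -> it.size) }.collect().toMap
    assert(before == after)  // every element ends up in the partition it started in
    sc.stop()
  }
}
```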
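The core of the new `GlobalLimitExec.doExecute` is the allocation loop that decides how many rows to take from each shuffle partition when the flat global limit is enabled. Here is a standalone model of that loop; the `if/else` from the patch is condensed into a `math.min`, and the clamp at the top stands in for the `sumOfOutput <= limit` branch that runs before the loop.

```scala
object FlatLimitAllocationSketch {
  // For each partition, decide how many rows to take so the total equals `limit`,
  // spreading the take as evenly as possible across the non-empty partitions.
  def allocate(limit: Long, numberOfOutput: Seq[Long]): Array[Long] = {
    // GlobalLimitExec only enters this loop when the total output exceeds the limit.
    var rowsNeedToTake: Long = math.min(limit, numberOfOutput.sum)
    val takeAmountByPartition = Array.fill[Long](numberOfOutput.length)(0L)
    val remainingRowsByPartition = numberOfOutput.toArray

    while (rowsNeedToTake > 0) {
      val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
      // When fewer rows remain to take than there are non-empty partitions, take one row
      // from each; otherwise split what is still needed evenly across them.
      val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
      remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
        // rowsNeedToTake can hit zero mid-traversal when it is smaller than nonEmptyParts.
        if (rowsNeedToTake > 0 && num > 0) {
          val toTake = math.min(num, takePerPart)
          rowsNeedToTake -= toTake
          takeAmountByPartition(index) += toTake
          remainingRowsByPartition(index) -= toTake
        }
      }
    }
    takeAmountByPartition
  }

  def main(args: Array[String]): Unit = {
    // Skewed shuffle output: most rows sit in partition 0, partition 2 is empty.
    val taken = allocate(limit = 12, Seq(100L, 3L, 0L, 7L))
    assert(taken.sum == 12L)
    println(taken.mkString(", "))  // 5, 3, 0, 4 -- spread across the non-empty partitions
  }
}
```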
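The behaviour is gated by the new `spark.sql.limit.flatGlobalLimit` conf (default `true`). A small usage sketch mirroring the updated tests — the golden files toggle it with a SQL `SET`, and `DataFrameAggregateSuite` uses `withSQLConf` — assuming a local session on a build that includes this patch (on builds without it the setting is simply stored and ignored):

```scala
import org.apache.spark.sql.SparkSession

object FlatGlobalLimitToggleSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("flat-global-limit-sketch")
      .getOrCreate()

    // Sequential, Hive-like limit behaviour, as set in the golden-file tests above.
    spark.sql("SET spark.sql.limit.flatGlobalLimit=false")

    // Flat behaviour (the default): GlobalLimitExec pulls rows from many partitions,
    // which is what the "limit for skew dataframe" test above exercises.
    spark.conf.set("spark.sql.limit.flatGlobalLimit", "true")
    val df = spark.range(0, 1000).repartition(100).limit(50)
    assert(df.rdd.count() == 50L)

    spark.stop()
  }
}
```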