From 4e31bb7959cb774b51d6d8662f53a3ad96b4dc49 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Jan 2017 08:21:04 +0000 Subject: [PATCH 01/18] Use map output statistics to improve global limit's parallelism. --- .../sort/BypassMergeSortShuffleWriter.java | 7 +- .../shuffle/sort/UnsafeShuffleWriter.java | 6 +- .../apache/spark/MapOutputStatistics.scala | 6 +- .../org/apache/spark/MapOutputTracker.scala | 6 +- .../apache/spark/scheduler/MapStatus.scala | 40 +++++-- .../shuffle/sort/SortShuffleWriter.scala | 4 +- .../util/collection/ExternalSorter.scala | 7 +- .../sort/UnsafeShuffleWriterSuite.java | 2 + .../apache/spark/MapOutputTrackerSuite.scala | 24 ++-- .../scala/org/apache/spark/ShuffleSuite.scala | 1 + .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/MapStatusSuite.scala | 6 +- .../serializer/KryoSerializerSuite.scala | 3 +- .../plans/physical/partitioning.scala | 15 +++ .../execution/exchange/ShuffleExchange.scala | 10 ++ .../apache/spark/sql/execution/limit.scala | 111 +++++++++++++++--- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../test/resources/sql-tests/inputs/limit.sql | 2 +- .../resources/sql-tests/results/limit.sql.out | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +- .../execution/ExchangeCoordinatorSuite.scala | 6 +- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../TakeOrderedAndProjectSuite.scala | 49 ++++---- .../execution/HiveCompatibilitySuite.scala | 4 + .../sql/hive/execution/PruningSuite.scala | 8 ++ 26 files changed, 271 insertions(+), 88 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a15559e55cb..125ac685409d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -126,7 +126,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -146,10 +146,12 @@ public void write(Iterator> records) throws IOException { // included in the shuffle write time.
writeMetrics.incWriteTime(System.nanoTime() - openStartTime); + int numOfRecords = 0; while (records.hasNext()) { final Product2 record = records.next(); final K key = record._1(); partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + numOfRecords += 1; } for (int i = 0; i < numPartitions; i++) { @@ -168,7 +170,8 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, numOfRecords); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 8a1771848dee..05e3d0d28746 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -82,6 +82,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; + private int numOfRecords = 0; + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { MyByteArrayOutputStream(int size) { super(size); } @@ -165,6 +167,7 @@ public void write(scala.collection.Iterator> records) throws IOEx try { while (records.hasNext()) { insertRecordIntoSorter(records.next()); + numOfRecords += 1; } closeAndWriteOutput(); success = true; @@ -227,7 +230,8 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, numOfRecords); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cb..2d47137f3ed7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param numberOfOutput number of output for each pre-map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val numberOfOutput: Array[Int]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ca442b629fd..b8e5cca10365 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -156,12 +156,14 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // Synchronize on the returned array because, on the driver, it gets mutated in place statuses.synchronized { val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { + val numberOfOutput = new 
Array[Int](statuses.length) + statuses.zipWithIndex.map { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + numberOfOutput(index) = s.numberOfOutput } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, numberOfOutput) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index b2e9a97129f0..364bb495805e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -26,7 +26,8 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -39,16 +40,18 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + def numberOfOutput: Int } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Int): MapStatus = { if (uncompressedSizes.length > 2000) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -91,29 +94,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Int) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Int) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Int = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeInt(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readInt() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -133,17 +141,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long) + private[this] var avgSize: Long, + private[this] var numOutput: Int) 
extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Int = numOutput + override def getSizeForBlock(reduceId: Int): Long = { if (emptyBlocks.contains(reduceId)) { 0 @@ -154,12 +165,14 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeInt(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readInt() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -167,7 +180,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Int): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -195,6 +211,6 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, numOutput) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf..2f0e0cd83d73 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -60,7 +60,7 @@ private[spark] class SortShuffleWriter[K, V, C]( new ExternalSorter[K, V, V]( context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } - sorter.insertAll(records) + val numOfRecords = sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -71,7 +71,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, numOfRecords) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d..6ad027e3ffb0 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -176,10 +176,12 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - def insertAll(records: Iterator[Product2[K, V]]): Unit = { + def insertAll(records: Iterator[Product2[K, V]]): Int = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined + var numOfRecords: Int = 0 + if (shouldCombine) { // Combine values in-memory first using our AppendOnlyMap val mergeValue = aggregator.get.mergeValue @@ -193,6 +195,7 @@ private[spark] class ExternalSorter[K, V, C]( kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) + numOfRecords += 1 } } else { // Stick values into our buffer @@ -201,8 +204,10 @@ private[spark] class ExternalSorter[K, V, C]( val kv = records.next() buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) + numOfRecords += 1 } } + numOfRecords } /** diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 088b68132d90..ad67226efb4d 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -246,6 +246,7 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.>emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -265,6 +266,7 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index bb24c6ce4d33..51a6e2bcf0c3 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -61,9 +61,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -83,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, 
compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -106,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -144,7 +144,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === @@ -180,7 +180,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -214,11 +214,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
@@ -259,7 +259,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index e626ed3621d6..5dec2f8ade51 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -364,6 +364,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index f3d3f701af46..5a3f1a36a56d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2215,7 +2215,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 759d52fca5ce..8e8f726a29a9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -55,7 +55,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -69,7 +69,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -81,7 +81,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) 
assert(status1.location == loc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a30653bb36fa..d47509759876 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -335,7 +335,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233f..9aaa4c7d2e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -229,6 +229,21 @@ case object SinglePartition extends Partitioning { override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } +/** + * Represents a partitioning where rows are only serialized/deserialized locally. Neither the + * number of partitions nor the distribution of rows is changed. This is mainly used to + * obtain statistics of map tasks, such as the number of outputs. + */ +case class LocalPartitioning(orgPartition: Partitioning, numPartitions: Int) extends Partitioning { + // We will perform this partitioning no matter what the data distribution is. + override def satisfies(required: Distribution): Boolean = false + + override def compatibleWith(other: Partitioning): Boolean = + orgPartition.compatibleWith(other) + + override def guarantees(other: Partitioning): Boolean = orgPartition.guarantees(other) +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 125a4930c652..95c5dbbcc50d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -221,6 +221,12 @@ object ShuffleExchange { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case LocalPartitioning(prev, numParts) => + new Partitioner { + override def numPartitions: Int = numParts + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } + case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning.
} @@ -237,6 +243,10 @@ object ShuffleExchange { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case LocalPartitioning(_, _) => + (row: InternalRow) => { + TaskContext.get().partitionId() + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") } val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 757fe2185d30..7fd480c28f6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow @@ -47,13 +49,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -90,25 +95,101 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the `limit` elements of the child output. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(child.outputPartitioning, + childRDD.getNumPartitions) + val shuffleDependency = ShuffleExchange.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Int] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) + submittedStageFuture.get().numberOfOutput.toSeq + } else { + Nil + } - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + // Try to keep child plan's original data parallelism or not. 
It is enabled by default. + val respectChildParallelism = sqlContext.conf.enableParallelGlobalLimit - override def outputPartitioning: Partitioning = child.outputPartitioning + val shuffled = new ShuffledRowRDD(shuffleDependency) - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + val sumOfOutput = numberOfOutput.sum + if (sumOfOutput <= limit) { + shuffled + } else if (!respectChildParallelism) { + // This is mainly for tests. + // We take the rows of each partition until we reach the required limit number. + var countForRows = 0 + val takeAmounts = new mutable.HashMap[Int, Int]() + numberOfOutput.zipWithIndex.foreach { case (num, index) => + if (countForRows + num < limit) { + countForRows += num + takeAmounts += ((index, num)) + } else { + val toTake = limit - countForRows + countForRows += toTake + takeAmounts += ((index, toTake)) + } + } + val broadMap = sparkContext.broadcast(takeAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + broadMap.value.get(index).map { size => + iter.take(size) + }.get + } + } else { + // We try to distribute the required limit number of rows across all child rdd's partitions. + var numToReduce = (sumOfOutput - limit) + val reduceAmounts = new mutable.HashMap[Int, Int]() + val nonEmptyParts = numberOfOutput.filter(_ > 0).size + val reducePerPart = numToReduce / nonEmptyParts + numberOfOutput.zipWithIndex.foreach { case (num, index) => + if (num >= reducePerPart) { + numToReduce -= reducePerPart + reduceAmounts += ((index, reducePerPart)) + } else { + numToReduce -= num + reduceAmounts += ((index, num)) + } + } + while (numToReduce > 0) { + numberOfOutput.zipWithIndex.foreach { case (num, index) => + val toReduce = if (numToReduce / nonEmptyParts > 0) { + numToReduce / nonEmptyParts + } else { + numToReduce + } + if (num - reduceAmounts(index) >= toReduce) { + reduceAmounts(index) = reduceAmounts(index) + toReduce + numToReduce -= toReduce + } else if (num - reduceAmounts(index) > 0) { + reduceAmounts(index) = reduceAmounts(index) + 1 + numToReduce -= 1 + } + } + } + val broadMap = sparkContext.broadcast(reduceAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + broadMap.value.get(index).map { size => + iter.drop(size) + }.get + } + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 645b0fa13ee3..bb1c2a73eb7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -122,6 +122,15 @@ object SQLConf { .intConf .createWithDefault(4) + val ENABLE_PARALLEL_GLOBAL_LIMIT = SQLConfigBuilder("spark.sql.limit.globalparallel") + .internal() + .doc("Not to shuffle the results of local limit to one single partition in global limit " + + "so that the limit operation doesn't downgrade parallelism. The config is mainly used " + + "in tests especially Hive compatibility test cases which assume there is an order in " + + "the returned rows of limit operation.") + .booleanConf + .createWithDefault(true) + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = SQLConfigBuilder("spark.sql.statistics.fallBackToHdfs") .doc("If the table statistics are not available from table metadata enable fall back to hdfs." 
+ @@ -780,6 +789,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def enableParallelGlobalLimit: Boolean = getConf(ENABLE_PARALLEL_GLOBAL_LIMIT) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index 2ea35f7f3a5c..cb6bc66ff178 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -20,4 +20,4 @@ select * from testdata limit true; select * from testdata limit 'a'; -- limit within a subquery -select * from (select * from range(10) limit 5) where id > 3; +select * from (select * from range(10) order by id limit 5) where id > 3; diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index cb4e4d04810d..da1a99b73891 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -84,7 +84,7 @@ The limit expression must be integer type, but got string; -- !query 9 -select * from (select * from range(10) limit 5) where id > 3 +select * from (select * from range(10) order by id limit 5) where id > 3 -- !query 9 schema struct -- !query 9 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7d..4ca253d1a07e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -524,11 +524,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 605dec4a1ef9..f5cd339d6677 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -530,6 +530,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(e.contains(expected)) } + test("limit for skew dataframe") { + // Create a skew dataframe. + val df = testData.repartition(100).union(testData).limit(50) + // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` + // work on skew partitions. 
+ assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1887,7 +1896,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 06bce9a2400e..379d86157683 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Int](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Int](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Int](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0bfc92fdb621..0ebb3eb32fac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -216,7 +216,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchange => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -231,7 +231,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchange => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 7e317a4d8026..5ac0ca59a863 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.types._ @@ -55,32 +56,36 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) + withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) + } } } test("TakeOrderedAndProject.doExecute with project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) + withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) + } } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d217e9b4feb6..09b2124fcc9c 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,6 +41,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalParallelGlobalLimit = TestHive.conf.enableParallelGlobalLimit def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -62,6 +63,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation returns rows in the same order as Hive + TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, false) RuleExecutor.resetTime() } @@ -74,6 +77,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, 
originalParallelGlobalLimit) // For debugging dump some statistics about how much time was spent in various optimizer rules logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 24df73b40ea0..5c406c0aab74 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalParallelGlobalLimit = TestHive.conf.enableParallelGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, originalParallelGlobalLimit) + super.afterAll() + } // Column pruning tests From 45a1fcb729cd551fbba1c633af713152fd63c24c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 27 Feb 2017 06:05:28 +0000 Subject: [PATCH 02/18] Use Long for number of outputs. Turn to the approach of calculating number of rows to take in each partition. --- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/UnsafeShuffleWriter.java | 2 +- .../apache/spark/MapOutputStatistics.scala | 2 +- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../apache/spark/scheduler/MapStatus.scala | 24 ++++---- .../util/collection/ExternalSorter.scala | 4 +- .../apache/spark/sql/execution/limit.scala | 55 ++++++++++--------- .../execution/ExchangeCoordinatorSuite.scala | 6 +- 8 files changed, 49 insertions(+), 48 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 125ac685409d..400a8f6911ff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -146,7 +146,7 @@ public void write(Iterator> records) throws IOException { // included in the shuffle write time. 
writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - int numOfRecords = 0; + long numOfRecords = 0; while (records.hasNext()) { final Product2 record = records.next(); final K key = record._1(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 05e3d0d28746..50b7229ad851 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -82,7 +82,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; - private int numOfRecords = 0; + private long numOfRecords = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index 2d47137f3ed7..8254bef80f7f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -28,4 +28,4 @@ package org.apache.spark private[spark] class MapOutputStatistics( val shuffleId: Int, val bytesByPartitionId: Array[Long], - val numberOfOutput: Array[Int]) + val numberOfOutput: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index b8e5cca10365..cb7c82d77c9d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -156,7 +156,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // Synchronize on the returned array because, on the driver, it gets mutated in place statuses.synchronized { val totalSizes = new Array[Long](dep.partitioner.numPartitions) - val numberOfOutput = new Array[Int](statuses.length) + val numberOfOutput = new Array[Long](statuses.length) statuses.zipWithIndex.map { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 364bb495805e..e0e591c223c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -41,13 +41,13 @@ private[spark] sealed trait MapStatus { */ def getSizeForBlock(reduceId: Int): Long - def numberOfOutput: Int + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Int): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { if (uncompressedSizes.length > 2000) { HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { @@ -95,18 +95,18 @@ private[spark] object MapStatus { private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, private[this] var compressedSizes: Array[Byte], - private[this] var numOutput: Int) + private[this] var numOutput: Long) extends MapStatus with Externalizable { protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: 
Array[Long], numOutput: Int) { + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc - override def numberOfOutput: Int = numOutput + override def numberOfOutput: Long = numOutput override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) @@ -114,14 +114,14 @@ private[spark] class CompressedMapStatus( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeInt(numOutput) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - numOutput = in.readInt() + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -142,7 +142,7 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var numOutput: Int) + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -153,7 +153,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def location: BlockManagerId = loc - override def numberOfOutput: Int = numOutput + override def numberOfOutput: Long = numOutput override def getSizeForBlock(reduceId: Int): Long = { if (emptyBlocks.contains(reduceId)) { @@ -165,14 +165,14 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeInt(numOutput) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - numOutput = in.readInt() + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -183,7 +183,7 @@ private[spark] object HighlyCompressedMapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - numOutput: Int): HighlyCompressedMapStatus = { + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
var i = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6ad027e3ffb0..adb91dcfee63 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -176,11 +176,11 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - def insertAll(records: Iterator[Product2[K, V]]): Int = { + def insertAll(records: Iterator[Product2[K, V]]): Long = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined - var numOfRecords: Int = 0 + var numOfRecords: Long = 0 if (shouldCombine) { // Combine values in-memory first using our AppendOnlyMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 7fd480c28f6f..6375e339de2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -113,7 +113,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { childRDD.getNumPartitions) val shuffleDependency = ShuffleExchange.prepareShuffleDependency( childRDD, child.output, partitioner, serializer) - val numberOfOutput: Seq[Int] = if (shuffleDependency.rdd.getNumPartitions != 0) { + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { // submitMapStage does not accept RDD with 0 partition. // So, we will not submit this dependency. val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) @@ -133,15 +133,15 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { } else if (!respectChildParallelism) { // This is mainly for tests. // We take the rows of each partition until we reach the required limit number. - var countForRows = 0 + var numTakenRow = 0 val takeAmounts = new mutable.HashMap[Int, Int]() numberOfOutput.zipWithIndex.foreach { case (num, index) => - if (countForRows + num < limit) { - countForRows += num - takeAmounts += ((index, num)) + if (numTakenRow + num < limit) { + numTakenRow += num.toInt + takeAmounts += ((index, num.toInt)) } else { - val toTake = limit - countForRows - countForRows += toTake + val toTake = limit - numTakenRow + numTakenRow += toTake takeAmounts += ((index, toTake)) } } @@ -153,39 +153,40 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { } } else { // We try to distribute the required limit number of rows across all child rdd's partitions. 
- var numToReduce = (sumOfOutput - limit) - val reduceAmounts = new mutable.HashMap[Int, Int]() + var numTakenRow = 0 + val takeAmounts = new mutable.HashMap[Int, Int]() val nonEmptyParts = numberOfOutput.filter(_ > 0).size - val reducePerPart = numToReduce / nonEmptyParts + val takePerPart = limit / nonEmptyParts numberOfOutput.zipWithIndex.foreach { case (num, index) => - if (num >= reducePerPart) { - numToReduce -= reducePerPart - reduceAmounts += ((index, reducePerPart)) + if (num >= takePerPart) { + numTakenRow += takePerPart + takeAmounts += ((index, takePerPart)) } else { - numToReduce -= num - reduceAmounts += ((index, num)) + numTakenRow += num.toInt + takeAmounts += ((index, num.toInt)) } } - while (numToReduce > 0) { + var remainingRow = limit - numTakenRow + while (remainingRow > 0) { numberOfOutput.zipWithIndex.foreach { case (num, index) => - val toReduce = if (numToReduce / nonEmptyParts > 0) { - numToReduce / nonEmptyParts + val toTake = if (remainingRow / nonEmptyParts > 0) { + remainingRow / nonEmptyParts } else { - numToReduce + remainingRow } - if (num - reduceAmounts(index) >= toReduce) { - reduceAmounts(index) = reduceAmounts(index) + toReduce - numToReduce -= toReduce - } else if (num - reduceAmounts(index) > 0) { - reduceAmounts(index) = reduceAmounts(index) + 1 - numToReduce -= 1 + if (num - takeAmounts(index) >= toTake) { + takeAmounts(index) = takeAmounts(index) + toTake + remainingRow -= toTake + } else if (num - takeAmounts(index) > 0) { + takeAmounts(index) = takeAmounts(index) + 1 + remainingRow -= 1 } } } - val broadMap = sparkContext.broadcast(reduceAmounts) + val broadMap = sparkContext.broadcast(takeAmounts) shuffled.mapPartitionsWithIndexInternal { case (index, iter) => broadMap.value.get(index).map { size => - iter.drop(size) + iter.take(size) }.get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 379d86157683..3d8f8fc2c38c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId, Array[Int](1)) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1, Array[Int](0)), - new MapOutputStatistics(1, bytesByPartitionId2, Array[Int](0))) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } From 1a5625222febf99bb89732d2cedd773923d5f6c6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Mar 2017 04:20:22 +0000 Subject: [PATCH 03/18] Rebased with latest change. 
--- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9004a4dd91f6..5ccc4dfffb8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -133,7 +133,7 @@ object SQLConf { .intConf .createWithDefault(4) - val ENABLE_PARALLEL_GLOBAL_LIMIT = SQLConfigBuilder("spark.sql.limit.globalparallel") + val ENABLE_PARALLEL_GLOBAL_LIMIT = buildConf("spark.sql.limit.globalparallel") .internal() .doc("Not to shuffle the results of local limit to one single partition in global limit " + "so that the limit operation doesn't downgrade parallelism. The config is mainly used " + From 2d37598d20f8a2c536eab4bd6a7e4da26f8bdc61 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 7 Mar 2017 09:02:50 +0000 Subject: [PATCH 04/18] Changed Limit outputs different results. It affects the test case output. --- .../sql-tests/results/subquery/in-subquery/in-limit.sql.out | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f864947..e01d7386bbad 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -93,8 +93,6 @@ struct Date: Thu, 11 May 2017 13:18:04 +0000 Subject: [PATCH 05/18] Address comments. --- .../apache/spark/MapOutputStatistics.scala | 4 +- .../org/apache/spark/MapOutputTracker.scala | 8 +-- .../apache/spark/scheduler/MapStatus.scala | 3 + .../execution/exchange/ShuffleExchange.scala | 2 +- .../apache/spark/sql/execution/limit.scala | 65 +++++++++---------- 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index 8254bef80f7f..8e8de3a97d3b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,9 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) - * @param numberOfOutput number of output for each pre-map output partition + * @param recordsByMapTask number of output records for each map task */ private[spark] class MapOutputStatistics( val shuffleId: Int, val bytesByPartitionId: Array[Long], - val numberOfOutput: Array[Long]) + val recordsByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 8dfeaa8098ec..1a358793c2ce 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -156,14 +156,14 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // Synchronize on the returned array because, on the driver, it gets mutated in place statuses.synchronized { val totalSizes = new Array[Long](dep.partitioner.numPartitions) - val numberOfOutput = 
new Array[Long](statuses.length) - statuses.zipWithIndex.map { case (s, index) => + val recordsByMapTask = new Array[Long](statuses.length) + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } - numberOfOutput(index) = s.numberOfOutput + recordsByMapTask(index) = s.numberOfOutput } - new MapOutputStatistics(dep.shuffleId, totalSizes, numberOfOutput) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index e0e591c223c9..5ba16caf3011 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -41,6 +41,9 @@ private[spark] sealed trait MapStatus { */ def getSizeForBlock(reduceId: Int): Long + /** + * The number of outputs for the map task. + */ def numberOfOutput: Long } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 3f5f7524004b..27703fef0d50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -221,7 +221,7 @@ object ShuffleExchange { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } - case LocalPartitioning(prev, numParts) => + case LocalPartitioning(_, numParts) => new Partitioner { override def numPartitions: Int = numParts override def getPartition(key: Any): Int = key.asInstanceOf[Int] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 6375e339de2c..7f90a99d558a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -117,7 +117,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { // submitMapStage does not accept RDD with 0 partition. // So, we will not submit this dependency. val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) - submittedStageFuture.get().numberOfOutput.toSeq + submittedStageFuture.get().recordsByPartitionId.toSeq } else { Nil } @@ -132,7 +132,10 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { shuffled } else if (!respectChildParallelism) { // This is mainly for tests. - // We take the rows of each partition until we reach the required limit number. + // Some tests like hive compatibility tests assume that the rows are returned by a specified + // order that the partitions are scaned sequentially until we reach the required number of + // rows. However, logically a Limit operator should not care the row scan order. + // Thus we take the rows of each partition until we reach the required limit number. var numTakenRow = 0 val takeAmounts = new mutable.HashMap[Int, Int]() numberOfOutput.zipWithIndex.foreach { case (num, index) => @@ -152,42 +155,36 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { }.get } } else { - // We try to distribute the required limit number of rows across all child rdd's partitions. 
- var numTakenRow = 0 - val takeAmounts = new mutable.HashMap[Int, Int]() - val nonEmptyParts = numberOfOutput.filter(_ > 0).size - val takePerPart = limit / nonEmptyParts - numberOfOutput.zipWithIndex.foreach { case (num, index) => - if (num >= takePerPart) { - numTakenRow += takePerPart - takeAmounts += ((index, takePerPart)) - } else { - numTakenRow += num.toInt - takeAmounts += ((index, num.toInt)) - } - } - var remainingRow = limit - numTakenRow - while (remainingRow > 0) { - numberOfOutput.zipWithIndex.foreach { case (num, index) => - val toTake = if (remainingRow / nonEmptyParts > 0) { - remainingRow / nonEmptyParts - } else { - remainingRow - } - if (num - takeAmounts(index) >= toTake) { - takeAmounts(index) = takeAmounts(index) + toTake - remainingRow -= toTake - } else if (num - takeAmounts(index) > 0) { - takeAmounts(index) = takeAmounts(index) + 1 - remainingRow -= 1 + // We try to evenly require the asked limit number of rows across all child rdd's partitions. + var rowsNeedToTake: Long = limit + val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) + val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) + + while (rowsNeedToTake > 0) { + val nonEmptyParts = remainingRowsByPartition.count(_ > 0) + // If the rows needed to take are less the number of non-empty partitions, take one row from + // each non-empty partitions until we reach `limit` rows. + // Otherwise, evenly divide the needed rows to each non-empty partitions. + val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) + remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => + // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during + // the traversal, so we need to add this check. + if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } } } } - val broadMap = sparkContext.broadcast(takeAmounts) + val broadMap = sparkContext.broadcast(takeAmountByPartition) shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - broadMap.value.get(index).map { size => - iter.take(size) - }.get + iter.take(broadMap.value(index).toInt) } } } From e53648e7f58f439bb09a702521c2f84cf2e344bd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 31 Oct 2017 05:18:34 +0000 Subject: [PATCH 06/18] ShuffleExchange becomes ShuffleExchangeExec now. 
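A standalone sketch of the even-distribution loop introduced in the limit.scala hunk above. It assumes the limit is smaller than the total number of shuffle output rows, which the surrounding code guarantees (when sumOfOutput <= limit the shuffled RDD is returned unchanged). The function name and the min-based inner step are simplifications of the same logic, not code from the patch.

    // Spread `limit` rows across non-empty partitions as evenly as possible.
    // Each pass hands every non-empty partition an equal share (at least one
    // row, so the loop still makes progress when fewer rows than partitions
    // remain to be taken) until the budget is used up.
    def evenTakeAmounts(recordsPerPartition: Seq[Long], limit: Long): Array[Long] = {
      var rowsNeedToTake = limit
      val takeAmounts = Array.fill(recordsPerPartition.length)(0L)
      val remaining = recordsPerPartition.toArray
      while (rowsNeedToTake > 0) {
        val nonEmptyParts = remaining.count(_ > 0)
        val takePerPart = math.max(1L, rowsNeedToTake / nonEmptyParts)
        remaining.indices.foreach { i =>
          if (rowsNeedToTake > 0 && remaining(i) > 0) {
            val take = math.min(takePerPart, remaining(i))
            takeAmounts(i) += take
            remaining(i) -= take
            rowsNeedToTake -= take
          }
        }
      }
      takeAmounts
    }

    // Example: evenTakeAmounts(Seq(5L, 0L, 3L, 10L), 6L) returns
    // Array(2L, 0L, 2L, 2L); the array is then broadcast and each reduce task
    // simply does iter.take(takeAmounts(partitionIndex).toInt).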
--- .../src/main/scala/org/apache/spark/sql/execution/limit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index a3b20b497b5d..f2ee6fe2b69a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -115,7 +115,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { val childRDD = child.execute() val partitioner = LocalPartitioning(child.outputPartitioning, childRDD.getNumPartitions) - val shuffleDependency = ShuffleExchange.prepareShuffleDependency( + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( childRDD, child.output, partitioner, serializer) val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { // submitMapStage does not accept RDD with 0 partition. From a691e885b1f304ba4037964e2fba09c540503e1a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 7 May 2018 09:06:52 +0000 Subject: [PATCH 07/18] Fix merging conflict. --- .../test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 775842f7b299..e79739692fe1 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 1)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( From 5594bf9f13aa83d05a433bad0fd366daabd2d034 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 May 2018 09:06:03 +0000 Subject: [PATCH 08/18] Avoid evenly scanning partitions when child output has ordering. --- .../src/main/scala/org/apache/spark/sql/execution/limit.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 7705ecbe83c5..d8f3cc2a6e33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -126,7 +126,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { } // Try to keep child plan's original data parallelism or not. It is enabled by default. - val respectChildParallelism = sqlContext.conf.enableParallelGlobalLimit + // If child output has certain ordering, we can't evenly pick up rows from each parititon. 
+ val respectChildParallelism = sqlContext.conf.enableParallelGlobalLimit && + child.outputOrdering != Nil val shuffled = new ShuffledRowRDD(shuffleDependency) From c9c8be6edde24ae1835a62e3a6a6e1d96265e1c5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 May 2018 10:09:39 +0000 Subject: [PATCH 09/18] Some refactoring. --- .../apache/spark/sql/internal/SQLConf.scala | 10 ++-- .../apache/spark/sql/execution/limit.scala | 16 +++---- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../TakeOrderedAndProjectSuite.scala | 48 +++++++++---------- .../execution/HiveCompatibilitySuite.scala | 6 +-- .../sql/hive/execution/PruningSuite.scala | 6 +-- 6 files changed, 39 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7334f9ce5cc9..c0cfed323b73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -186,12 +186,10 @@ object SQLConf { .intConf .createWithDefault(4) - val ENABLE_PARALLEL_GLOBAL_LIMIT = buildConf("spark.sql.limit.globalparallel") + val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") .internal() - .doc("Not to shuffle the results of local limit to one single partition in global limit " + - "so that the limit operation doesn't downgrade parallelism. The config is mainly used " + - "in tests especially Hive compatibility test cases which assume there is an order in " + - "the returned rows of limit operation.") + .doc("During global limit, try to evenly distribute limited rows across data " + + "partitions. If disabled, scanning data partitions sequentially until reaching limit number.") .booleanConf .createWithDefault(true) @@ -1452,7 +1450,7 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) - def enableParallelGlobalLimit: Boolean = getConf(ENABLE_PARALLEL_GLOBAL_LIMIT) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index d8f3cc2a6e33..6102a06691d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -125,22 +125,18 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { Nil } - // Try to keep child plan's original data parallelism or not. It is enabled by default. - // If child output has certain ordering, we can't evenly pick up rows from each parititon. - val respectChildParallelism = sqlContext.conf.enableParallelGlobalLimit && - child.outputOrdering != Nil + // During global limit, try to evenly distribute limited rows across data + // partitions. If disabled, scanning data partitions sequentially until reaching limit number. + // Besides, if child output has certain ordering, we can't evenly pick up rows from + // each parititon. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil val shuffled = new ShuffledRowRDD(shuffleDependency) val sumOfOutput = numberOfOutput.sum if (sumOfOutput <= limit) { shuffled - } else if (!respectChildParallelism) { - // This is mainly for tests. 
- // Some tests like hive compatibility tests assume that the rows are returned by a specified - // order that the partitions are scaned sequentially until we reach the required number of - // rows. However, logically a Limit operator should not care the row scan order. - // Thus we take the rows of each partition until we reach the required limit number. + } else if (!flatGlobalLimit) { var numTakenRow = 0 val takeAmounts = new mutable.HashMap[Int, Int]() numberOfOutput.zipWithIndex.foreach { case (num, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index bd9c35d32878..7febcf340bdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -556,7 +556,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") val limit2Df = df.limit(2) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 5ac0ca59a863..2f9f79905458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -56,36 +56,32 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { - withSQLConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index fa090777ef3d..b9b2b7dbf38e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,7 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled - private val originalParallelGlobalLimit = TestHive.conf.enableParallelGlobalLimit + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -61,7 +61,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) // Ensure that limit operation returns rows in the same order as Hive - TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) - TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, originalParallelGlobalLimit) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 254b1eb8643b..16541295eb45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -29,12 +29,12 @@ import org.apache.spark.sql.internal.SQLConf */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - private val originalParallelGlobalLimit = TestHive.conf.enableParallelGlobalLimit + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) - TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. 
Refer to https://issues.apache.org/jira/browse/SPARK-2283 @@ -42,7 +42,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { TestHive.reset() } override def afterAll() { - TestHive.setConf(SQLConf.ENABLE_PARALLEL_GLOBAL_LIMIT, originalParallelGlobalLimit) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) super.afterAll() } From ca00701ee1dbd18fe5b45f2fbfca80242da366f2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Jun 2018 07:34:47 +0000 Subject: [PATCH 10/18] Disable global limit optimization in limit sql query test. --- .../test/resources/sql-tests/inputs/limit.sql | 4 +- .../inputs/subquery/in-subquery/in-limit.sql | 5 +- .../resources/sql-tests/results/limit.sql.out | 78 ++++++++++--------- .../subquery/in-subquery/in-limit.sql.out | 60 +++++++------- 4 files changed, 84 insertions(+), 63 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index d3acf4d82289..c8ec9d9a192c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable global limit parallel +set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; @@ -21,7 +23,7 @@ SELECT * FROM testdata LIMIT true; SELECT * FROM testdata LIMIT 'a'; -- limit within a subquery -SELECT * FROM (SELECT * FROM range(10) ORDER BY id LIMIT 5) t WHERE id > 3; +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3; -- limit ALL SELECT * FROM testdata WHERE key < 3 LIMIT ALL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee082ba3b..a862e0985b20 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. 
+-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 187abc7a1f98..177cd845f488 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,63 +1,62 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 13 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct +struct -- !query 0 output -1 1 -2 2 +spark.sql.limit.flatGlobalLimit false -- !query 1 -SELECT * FROM arraydata LIMIT 2 +SELECT * FROM testdata LIMIT 2 -- !query 1 schema -struct,nestedarraycol:array>> +struct -- !query 1 output -[1,2,3] [[1,2,3]] -[2,3,4] [[2,3,4]] +1 1 +2 2 -- !query 2 -SELECT * FROM mapdata LIMIT 2 +SELECT * FROM arraydata LIMIT 2 -- !query 2 schema -struct> +struct,nestedarraycol:array>> -- !query 2 output -{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} -{1:"a2",2:"b2",3:"c2",4:"d2"} +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] -- !query 3 -SELECT * FROM testdata LIMIT 2 + 1 +SELECT * FROM mapdata LIMIT 2 -- !query 3 schema -struct +struct> -- !query 3 output -1 1 -2 2 -3 3 +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} -- !query 4 -SELECT * FROM testdata LIMIT CAST(1 AS int) +SELECT * FROM testdata LIMIT 2 + 1 -- !query 4 schema struct -- !query 4 output 1 1 +2 2 +3 3 -- !query 5 -SELECT * FROM testdata LIMIT -1 +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 5 schema -struct<> +struct -- !query 5 output -org.apache.spark.sql.AnalysisException -The limit expression must be equal to or greater than 0, but got -1; +1 1 -- !query 6 -SELECT * FROM testData TABLESAMPLE (-1 ROWS) +SELECT * FROM testdata LIMIT -1 -- !query 6 schema struct<> -- !query 6 output @@ -66,44 +65,53 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -SELECT * FROM testdata LIMIT key > 3 +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The limit expression must be equal to or greater than 0, but got -1; -- !query 8 -SELECT * FROM testdata LIMIT true +SELECT * FROM testdata LIMIT key > 3 -- !query 8 schema struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 9 -SELECT * FROM testdata LIMIT 'a' +SELECT * FROM testdata LIMIT true -- !query 9 schema struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 10 -SELECT * FROM (SELECT * FROM range(10) ORDER BY id LIMIT 5) t WHERE id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 10 schema -struct +struct<> -- !query 10 output -4 
+org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; -- !query 11 -SELECT * FROM testdata WHERE key < 3 LIMIT ALL +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 11 schema -struct +struct -- !query 11 output +4 + + +-- !query 12 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 12 schema +struct +-- !query 12 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index e01d7386bbad..9eb5b3383e73 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,14 +96,16 @@ WHERE t1c IN 
(SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 +val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 +val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -106,31 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 -val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 -val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output 1 6 From 21b6948e31d3d23b93aee590def273d94187f221 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Jun 2018 08:53:17 +0000 Subject: [PATCH 11/18] Use array instead of map. --- .../apache/spark/sql/execution/limit.scala | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 6102a06691d7..c8e1b117565c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -137,23 +137,21 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { if (sumOfOutput <= limit) { shuffled } else if (!flatGlobalLimit) { - var numTakenRow = 0 - val takeAmounts = new mutable.HashMap[Int, Int]() + var numRowTaken = 0 + val takeAmounts = mutable.ArrayBuffer.fill[Long](numberOfOutput.length)(0L) numberOfOutput.zipWithIndex.foreach { case (num, index) => - if (numTakenRow + num < limit) { - numTakenRow += num.toInt - takeAmounts += ((index, num.toInt)) + if (numRowTaken + num < limit) { + numRowTaken += num.toInt + takeAmounts(index) += num.toInt } else { - val toTake = limit - numTakenRow - numTakenRow += toTake - takeAmounts += ((index, toTake)) + val toTake = limit - numRowTaken + numRowTaken += toTake + takeAmounts(index) += toTake } } val broadMap = sparkContext.broadcast(takeAmounts) shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - broadMap.value.get(index).map { size => - iter.take(size) - }.get + iter.take(broadMap.value(index).toInt) } } else { // We try to evenly require the asked limit number of rows across all child rdd's partitions. From 1ff1fa56735722887c8eb0d6be42bf42fa580fa1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Jun 2018 09:25:18 +0000 Subject: [PATCH 12/18] Resolve merging issue. 
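The SQL test changes above drive the new behavior through the spark.sql.limit.flatGlobalLimit flag. A minimal usage sketch for a build with this patch series applied; the object name, master setting, and the toy query are illustrative only.

    import org.apache.spark.sql.SparkSession

    object FlatGlobalLimitDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[4]")
          .appName("flat-global-limit-demo")
          .getOrCreate()
        // Opt out of the evenly distributed global limit, as the updated SQL
        // tests do with `set spark.sql.limit.flatGlobalLimit=false`.
        spark.conf.set("spark.sql.limit.flatGlobalLimit", "false")
        // A limit that is not the outermost operator is planned as
        // LocalLimitExec + GlobalLimitExec; with the flag off, rows should be
        // taken from the leading partitions in order.
        spark.range(0, 100, 1, 10).limit(30).groupBy().count().show()
        spark.stop()
      }
    }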
--- .../apache/spark/sql/catalyst/plans/physical/partitioning.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2294dc4c0013..1ea1c399db6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -213,7 +213,7 @@ case object SinglePartition extends Partitioning { */ case class LocalPartitioning(orgPartition: Partitioning, numPartitions: Int) extends Partitioning { // We will perform this partitioning no matter what the data distribution is. - override def satisfies(required: Distribution): Boolean = false + override def satisfies0(required: Distribution): Boolean = false } /** From a7375736b38b0955b47dc48e710114b3ae97234b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Jun 2018 12:51:30 +0000 Subject: [PATCH 13/18] Address comment. --- .../scala/org/apache/spark/sql/execution/limit.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index c8e1b117565c..398cdca92ba4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable - import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow @@ -138,15 +136,14 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { shuffled } else if (!flatGlobalLimit) { var numRowTaken = 0 - val takeAmounts = mutable.ArrayBuffer.fill[Long](numberOfOutput.length)(0L) - numberOfOutput.zipWithIndex.foreach { case (num, index) => + val takeAmounts = numberOfOutput.map { num => if (numRowTaken + num < limit) { numRowTaken += num.toInt - takeAmounts(index) += num.toInt + num.toInt } else { val toTake = limit - numRowTaken numRowTaken += toTake - takeAmounts(index) += toTake + toTake } } val broadMap = sparkContext.broadcast(takeAmounts) From f24171e4364aa5d429dd76a22dc9991f7fff77f8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Jun 2018 00:37:42 +0000 Subject: [PATCH 14/18] Address comment. 
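For reference, the LocalPartitioning case handled in ShuffleExchangeExec (see the earlier hunks in this series and the one just below) boils down to an identity partitioner: the key attached to every record is its map-side partition id, so the shuffle preserves the original row layout and exists only to collect per-map-task statistics. A sketch of that partitioner with a named class instead of the patch's anonymous one:

    import org.apache.spark.Partitioner

    // Identity partitioner: a record keyed by its originating partition id
    // lands in the output partition with the same index. `IdentityPartitioner`
    // is an illustrative name, not a class added by the patch.
    class IdentityPartitioner(parts: Int) extends Partitioner {
      override def numPartitions: Int = parts
      override def getPartition(key: Any): Int = key.asInstanceOf[Int]
    }

The hunk below additionally makes the key extractor read TaskContext.get().partitionId() once per task instead of once per row.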
--- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 82d4909d83b0..ed5b75d7974c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -254,9 +254,8 @@ object ShuffleExchangeExec { row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case LocalPartitioning(_, _) => - (row: InternalRow) => { - TaskContext.get().partitionId() - } + val partitionId = TaskContext.get().partitionId() + _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } From 2d522b4d05cd544617bb7d78362aaf28bd15d62a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Jun 2018 12:13:02 +0000 Subject: [PATCH 15/18] Use orgPartition.numPartitions. --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 4 +++- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 6 +++--- .../main/scala/org/apache/spark/sql/execution/limit.scala | 3 +-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1ea1c399db6c..d0b1d6e49dbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -211,7 +211,9 @@ case object SinglePartition extends Partitioning { * of partitions are not changed and also the distribution of rows. This is mainly used to * obtain some statistics of map tasks such as number of outputs. */ -case class LocalPartitioning(orgPartition: Partitioning, numPartitions: Int) extends Partitioning { +case class LocalPartitioning(orgPartition: Partitioning) extends Partitioning { + val numPartitions = orgPartition.numPartitions + // We will perform this partitioning no matter what the data distribution is. 
override def satisfies0(required: Distribution): Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index ed5b75d7974c..86def862a485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,9 +231,9 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } - case LocalPartitioning(_, numParts) => + case l: LocalPartitioning => new Partitioner { - override def numPartitions: Int = numParts + override def numPartitions: Int = l.numPartitions override def getPartition(key: Any): Int = key.asInstanceOf[Int] } @@ -253,7 +253,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity - case LocalPartitioning(_, _) => + case _: LocalPartitioning => val partitionId = TaskContext.get().partitionId() _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 398cdca92ba4..0dae34d46f2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -110,8 +110,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { protected override def doExecute(): RDD[InternalRow] = { val childRDD = child.execute() - val partitioner = LocalPartitioning(child.outputPartitioning, - childRDD.getNumPartitions) + val partitioner = LocalPartitioning(child.outputPartitioning) val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( childRDD, child.output, partitioner, serializer) val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { From 97922200f9d898fe4fca1930f3bea8a3d12325e8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Jun 2018 23:39:54 +0000 Subject: [PATCH 16/18] Use childRDD. --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 6 ++++-- .../main/scala/org/apache/spark/sql/execution/limit.scala | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d0b1d6e49dbe..cd28c733f361 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -211,8 +213,8 @@ case object SinglePartition extends Partitioning { * of partitions are not changed and also the distribution of rows. This is mainly used to * obtain some statistics of map tasks such as number of outputs. 
*/ -case class LocalPartitioning(orgPartition: Partitioning) extends Partitioning { - val numPartitions = orgPartition.numPartitions +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions // We will perform this partitioning no matter what the data distribution is. override def satisfies0(required: Distribution): Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 0dae34d46f2b..392ca13724bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -110,7 +110,7 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { protected override def doExecute(): RDD[InternalRow] = { val childRDD = child.execute() - val partitioner = LocalPartitioning(child.outputPartitioning) + val partitioner = LocalPartitioning(childRDD) val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( childRDD, child.output, partitioner, serializer) val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { From 19d7d750357d32512d12b2b7520af2b46fbf145b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jun 2018 08:32:02 +0000 Subject: [PATCH 17/18] Use writeMetrics.recordsWritten. --- .../spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 4 +--- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 5 +---- .../main/scala/org/apache/spark/MapOutputStatistics.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 5 +++-- .../org/apache/spark/util/collection/ExternalSorter.scala | 7 +------ 5 files changed, 7 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0066471f9c58..e3bd5496cf5b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -145,12 +145,10 @@ public void write(Iterator> records) throws IOException { // included in the shuffle write time. 
writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - long numOfRecords = 0; while (records.hasNext()) { final Product2 record = records.next(); final K key = record._1(); partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - numOfRecords += 1; } for (int i = 0; i < numPartitions; i++) { @@ -170,7 +168,7 @@ public void write(Iterator> records) throws IOException { } } mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, numOfRecords); + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f9c7c011d85f..069e6d5f224d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -87,8 +87,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; - private long numOfRecords = 0; - /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { MyByteArrayOutputStream(int size) { super(size); } @@ -188,7 +186,6 @@ public void write(scala.collection.Iterator> records) throws IOEx try { while (records.hasNext()) { insertRecordIntoSorter(records.next()); - numOfRecords += 1; } closeAndWriteOutput(); success = true; @@ -252,7 +249,7 @@ void closeAndWriteOutput() throws IOException { } } mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, numOfRecords); + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index 8e8de3a97d3b..ff85e11409e3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,7 +23,7 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) - * @param recordsByMapTask number of output records for each map task + * @param recordsByPartitionId number of output records for each map output partition */ private[spark] class MapOutputStatistics( val shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index ffbc00b9d48a..91fc26762e53 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -59,7 +59,7 @@ private[spark] class SortShuffleWriter[K, V, C]( new ExternalSorter[K, V, V]( context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } - val numOfRecords = sorter.insertAll(records) + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = 
ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, numOfRecords) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index adb91dcfee63..176f84fa2a0d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -176,12 +176,10 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - def insertAll(records: Iterator[Product2[K, V]]): Long = { + def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined - var numOfRecords: Long = 0 - if (shouldCombine) { // Combine values in-memory first using our AppendOnlyMap val mergeValue = aggregator.get.mergeValue @@ -195,7 +193,6 @@ private[spark] class ExternalSorter[K, V, C]( kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) - numOfRecords += 1 } } else { // Stick values into our buffer @@ -204,10 +201,8 @@ private[spark] class ExternalSorter[K, V, C]( val kv = records.next() buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) - numOfRecords += 1 } } - numOfRecords } /** From d05c144aecdd57f4ee3d179a240ccafa6c02bb66 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Jul 2018 04:44:00 +0000 Subject: [PATCH 18/18] Revert unused change. --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 1 - .../apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 86def862a485..50f10c31427d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -236,7 +236,6 @@ object ShuffleExchangeExec { override def numPartitions: Int = l.numPartitions override def getPartition(key: Any): Int = key.asInstanceOf[Int] } - case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 2f9f79905458..7e317a4d8026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -22,7 +22,6 @@ import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._
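To summarize the statistics plumbing this series relies on: each map task reports its shuffle record count (writeMetrics.recordsWritten) through its MapStatus, the MapOutputTracker folds those counts into MapOutputStatistics alongside the per-reducer byte sizes, and GlobalLimitExec obtains them by submitting the map stage before deciding how many rows each partition keeps. A simplified, self-contained model of the driver-side aggregation; MapTaskStats and summarize are stand-ins for illustration, not Spark classes.

    // One entry per completed map task: byte sizes per reduce partition plus
    // the task's total record count.
    final case class MapTaskStats(bytesByReducer: Array[Long], recordsWritten: Long)

    // Sum bytes per reduce partition and keep one record count per map task,
    // mirroring what MapOutputTracker does when building MapOutputStatistics.
    def summarize(statuses: Seq[MapTaskStats], numReducers: Int): (Array[Long], Array[Long]) = {
      val totalSizes = new Array[Long](numReducers)
      val recordsByMapTask = new Array[Long](statuses.length)
      statuses.zipWithIndex.foreach { case (s, i) =>
        var r = 0
        while (r < numReducers) {
          totalSizes(r) += s.bytesByReducer(r)
          r += 1
        }
        recordsByMapTask(i) = s.recordsWritten
      }
      (totalSizes, recordsByMapTask)
    }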