diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 833744f4777c..0ac3a7b891a6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -83,6 +83,11 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private final int numElementsForSpillThreshold; + /** + * Force this sorter to spill when the size in memory is beyond this threshold. + */ + private final long recordsSizeForSpillThreshold; + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -106,6 +111,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long pageCursor = -1; + private long inMemRecordsSize = 0; ShuffleExternalSorter( TaskMemoryManager memoryManager, @@ -127,6 +133,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + this.recordsSizeForSpillThreshold = + (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, (boolean) conf.get(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT())); @@ -316,6 +324,7 @@ private long freeMemory() { allocatedPages.clear(); currentPage = null; pageCursor = 0; + inMemRecordsSize = 0; return memoryFreed; } @@ -397,11 +406,14 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p // for tests assert(inMemSorter != null); if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { - logger.info("Spilling data because number of spilledRecords crossed the threshold " + - numElementsForSpillThreshold); + logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold: {}", + inMemSorter.numRecords(), numElementsForSpillThreshold); + spill(); + } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) { + logger.info("Spilling data because size of spilledRecords ({}) crossed the threshold: {}", + inMemRecordsSize, recordsSizeForSpillThreshold); spill(); } - growPointerArrayIfNecessary(); final int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 or 8 bytes to store the record length. 
@@ -416,6 +428,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); + inMemRecordsSize += length; } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 55e4e609c3c7..d63f2fd0ab88 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -74,6 +74,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer { */ private final int numElementsForSpillThreshold; + /** + * Force this sorter to spill when the size in memory is beyond this threshold. + */ + private final long maxRecordsSizeForSpillThreshold; + /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -86,6 +91,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { // These variables are reset after spilling: @Nullable private volatile UnsafeInMemorySorter inMemSorter; + private long inMemRecordsSize = 0; private MemoryBlock currentPage = null; private long pageCursor = -1; @@ -104,10 +110,12 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, - pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); + serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, + pageSizeBytes, numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, + inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. 
sorter.inMemSorter = null; @@ -124,10 +132,11 @@ public static UnsafeExternalSorter create( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, - numElementsForSpillThreshold, null, canUseRadixSort); + numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null, canUseRadixSort); } private UnsafeExternalSorter( @@ -140,6 +149,7 @@ private UnsafeExternalSorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); @@ -170,6 +180,7 @@ private UnsafeExternalSorter( } this.peakMemoryUsedBytes = getMemoryUsage(); this.numElementsForSpillThreshold = numElementsForSpillThreshold; + this.maxRecordsSizeForSpillThreshold = maxRecordsSizeForSpillThreshold; // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator @@ -228,7 +239,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the // records. Otherwise, if the task is over allocated memory, then without freeing the memory // pages, we might not be able to get memory for the pointer array. - + inMemRecordsSize = 0; taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; @@ -396,8 +407,11 @@ public void insertRecord( logger.info("Spilling data because number of spilledRecords crossed the threshold " + numElementsForSpillThreshold); spill(); + } else if (inMemRecordsSize >= maxRecordsSizeForSpillThreshold) { + logger.info("Spilling data because size of spilledRecords crossed the threshold " + + maxRecordsSizeForSpillThreshold); + spill(); } - growPointerArrayIfNecessary(); int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 or 8 bytes to store the record length. @@ -411,6 +425,7 @@ public void insertRecord( Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + inMemRecordsSize += length; } /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 652db2bdf90a..71ad2ecc2e7a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -998,6 +998,26 @@ package object config { .intConf .createWithDefault(Integer.MAX_VALUE) + private[spark] val SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.map.maxRecordsSizeForSpillThreshold") + .internal() + .doc("The maximum size in memory before forcing the map-side shuffle sorter to spill. 
" + + "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + + private[spark] val SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.reduce.maxRecordsSizeForSpillThreshold") + .internal() + .doc("The maximum size in memory before forcing the reduce-side to spill. " + + "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD = ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold") .internal() diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 1983b0002853..e6df4ef799ae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -53,6 +53,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) private[this] val initialMemoryThreshold: Long = SparkEnv.get.conf.get(SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD) + // Force this collection to spill when its size is greater than this threshold + private[this] val maxSizeForceSpillThreshold: Long = + SparkEnv.get.conf.get(SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD) + // Force this collection to spill when there are this many elements in memory // For testing only private[this] val numElementsForceSpillThreshold: Int = @@ -81,7 +85,11 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { var shouldSpill = false - if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { + // Check number of elements or memory usage limits, whichever is hit first + if (_elementsRead > numElementsForceSpillThreshold + || currentMemory > maxSizeForceSpillThreshold) { + shouldSpill = true + } else if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold val granted = acquireMemory(amountToRequest) @@ -90,11 +98,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // or we already had more memory than myMemoryThreshold), spill the current collection shouldSpill = currentMemory >= myMemoryThreshold } - shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold // Actually spill if (shouldSpill) { _spillCount += 1 - logSpillage(currentMemory) + logSpillage(currentMemory, elementsRead) spill(collection) _elementsRead = 0 _memoryBytesSpilled += currentMemory @@ -141,10 +148,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * * @param size number of bytes spilled */ - @inline private def logSpillage(size: Long): Unit = { + @inline private def logSpillage(size: Long, elements: Int) { val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)" - .format(threadId, 
org.apache.spark.util.Utils.bytesToString(size), + logInfo("Thread %d spilling in-memory map of %s (elements: %d) to disk (%d time%s so far)" + .format(threadId, org.apache.spark.util.Utils.bytesToString(size), elements, _spillCount, if (_spillCount > 1) "s" else "")) } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 43977717f6c9..0b202f98d458 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -91,9 +91,12 @@ public int compare( private final long pageSizeBytes = conf.getSizeAsBytes( package$.MODULE$.BUFFER_PAGESIZE().key(), "4m"); - private final int spillThreshold = + private final int spillElementsThreshold = (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + private final long spillSizeThreshold = + (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()); + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -167,7 +170,8 @@ private UnsafeExternalSorter newSorter() throws IOException { prefixComparator, /* initialSize */ 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); } @@ -394,7 +398,8 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; @@ -456,7 +461,8 @@ public void testPeakMemoryUsed() throws Exception { prefixComparator, 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); // Peak memory should be monotonically increasing. 
More specifically, every time diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 700f6b773727..ebfca1f88708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1516,6 +1516,13 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + val WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by window operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") .internal() @@ -1531,6 +1538,13 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + val SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by sort merge join operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") .internal() @@ -1546,6 +1560,15 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by cartesian product operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + " as regular expressions.") @@ -2647,18 +2670,26 @@ class SQLConf extends Serializable with Logging { def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + def windowExecBufferSpillSizeThreshold: Long = getConf(WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def sortMergeJoinExecBufferInMemoryThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + def sortMergeJoinExecBufferSpillSizeThreshold: Long = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def cartesianProductExecBufferInMemoryThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) def cartesianProductExecBufferSpillThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + def cartesianProductExecBufferSizeSpillThreshold: Long = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC) def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java 
index 90b55a8586de..bd5dbc82b07f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -120,6 +120,8 @@ private UnsafeExternalRowSorter( pageSizeBytes, (int) SparkEnv.get().conf().get( package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + (long) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()), canUseRadixSort ); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 117e98f33a0e..ffd4e67fe541 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -247,6 +247,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti map.getPageSizeBytes(), (int) SparkEnv.get().conf().get( package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + (long) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index acd54fe25d62..d0ac80da6a69 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -59,9 +59,10 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - int numElementsForSpillThreshold) throws IOException { + int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold) throws IOException { this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, - numElementsForSpillThreshold, null); + numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null); } public UnsafeKVExternalSorter( @@ -71,6 +72,7 @@ public UnsafeKVExternalSorter( SerializerManager serializerManager, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -97,6 +99,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, + maxRecordsSizeForSpillThreshold, canUseRadixSort); } else { // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow @@ -163,6 +166,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, + maxRecordsSizeForSpillThreshold, inMemSorter); // reset the map, so we can re-use it to insert new records. 
the inMemSorter will not used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index ac282ea2e94f..1c016d4f8241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -50,9 +50,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) extends Logging { + numRowsSpillThreshold: Int, + maxSizeSpillThreshold: Long) extends Logging { - def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { + def this(numRowsInMemoryBufferThreshold: Int, + numRowsSpillThreshold: Int, + maxSizeSpillThreshold: Long) { this( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, @@ -61,7 +64,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( 1024, SparkEnv.get.memoryManager.pageSizeBytes, numRowsInMemoryBufferThreshold, - numRowsSpillThreshold) + numRowsSpillThreshold, + maxSizeSpillThreshold) } private val initialSizeOfInMemoryBuffer = @@ -122,6 +126,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize, pageSizeBytes, numRowsSpillThreshold, + maxSizeSpillThreshold, false) // populate with existing in-memory buffered rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 75651500954c..b3e8cdcf2408 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -316,6 +316,7 @@ class SortBasedAggregator( SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index b5372bcca89d..31676473ec8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -75,6 +75,7 @@ class ObjectAggregationMap() { SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD), null ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 29645a736548..a6476aa4b80c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -36,11 +36,13 @@ class UnsafeCartesianRDD( right : RDD[UnsafeRow], numFieldsOfRight: Int, inMemoryBufferThreshold: Int, - spillThreshold: Int) + 
spillThreshold: Int, + spillSizeThreshold: Long) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold, + spillSizeThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) @@ -88,7 +90,8 @@ case class CartesianProductExec( rightResults, right.output.size, sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, - sqlContext.conf.cartesianProductExecBufferSpillThreshold) + sqlContext.conf.cartesianProductExecBufferSpillThreshold, + sqlContext.conf.cartesianProductExecBufferSizeSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 62eea611556f..59197c6740bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -171,6 +171,10 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferSpillThreshold } + private def getSpillSizeThreshold: Long = { + sqlContext.conf.sortMergeJoinExecBufferSpillSizeThreshold + } + private def getInMemoryThreshold: Int = { sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold } @@ -178,6 +182,7 @@ case class SortMergeJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold + val spillSizeThreshold = getSpillSizeThreshold val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { @@ -206,6 +211,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -251,6 +257,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) val rightNullRow = new GenericInternalRow(right.output.length) @@ -266,6 +273,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) val leftNullRow = new GenericInternalRow(left.output.length) @@ -301,6 +309,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -337,6 +346,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -380,6 +390,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -712,6 +723,7 @@ private[joins] class SortMergeJoinScanner( bufferedIter: 
RowIterator, inMemoryThreshold: Int, spillThreshold: Int, + spillSizeThreshold: Long, eagerCleanupResources: () => Unit) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ @@ -724,7 +736,7 @@ private[joins] class SortMergeJoinScanner( private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, spillSizeThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index f54c4b8f2206..7102456ca1b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -199,6 +199,7 @@ case class WindowInPandasExec( val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold val spillThreshold = conf.windowExecBufferSpillThreshold + val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold val sessionLocalTimeZone = conf.sessionLocalTimeZone // Extract window expressions and window functions @@ -318,7 +319,8 @@ case class WindowInPandasExec( // Manage the current partition. val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, + spillSizeThreshold) var bufferIterator: Iterator[UnsafeRow] = _ val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index d191f3790ffa..18efae24e96a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -110,6 +110,7 @@ case class WindowExec( val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + val spillSizeThreshold = sqlContext.conf.windowExecBufferSpillSizeThreshold // Start processing. child.execute().mapPartitions { stream => @@ -137,7 +138,8 @@ case class WindowExec( // Manage the current partition. 
val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, + spillSizeThreshold) var bufferIterator: Iterator[UnsafeRow] = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 0869e25674e6..6f66104b2c82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -103,7 +103,8 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { for (_ <- 0L until iterations) { val array = new ExternalAppendOnlyUnsafeRowArray( ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, - numSpillThreshold) + numSpillThreshold, + Long.MaxValue) rows.foreach(x => array.add(x)) @@ -142,6 +143,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { 1024, SparkEnv.get.memoryManager.pageSizeBytes, numSpillThreshold, + Long.MaxValue, false) rows.foreach(x => @@ -166,7 +168,9 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, + numSpillThreshold, + Long.MaxValue) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index b29de9c4adba..fc255426c269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -50,7 +50,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar 1024, SparkEnv.get.memoryManager.pageSizeBytes, inMemoryThreshold, - spillThreshold) + spillThreshold, + Long.MaxValue) try f(array) finally { array.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 8aa003a3dfeb..48560311f0ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -127,7 +127,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get, + SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get + ) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => @@ -240,6 +242,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession 
sparkContext.env.serializerManager, taskMemoryManager.pageSizeBytes(), Int.MaxValue, + Long.MaxValue, map) } finally { TaskContext.unset()
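
For reference, a minimal usage sketch of the thresholds introduced above. The config keys are the ones added in this patch; the SparkSession setup and the `1g`/`2g` values are illustrative only, since all of these configs are internal and default to Long.MaxValue (size-based forced spilling stays disabled unless they are set explicitly):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative tuning only -- every threshold below defaults to Long.MaxValue.
val spark = SparkSession.builder()
  .appName("size-based-spill-thresholds")
  // Map-side shuffle sorter (ShuffleExternalSorter / UnsafeExternalSorter).
  .config("spark.shuffle.spill.map.maxRecordsSizeForSpillThreshold", "2g")
  // Reduce-side spillable collections (Spillable subclasses).
  .config("spark.shuffle.spill.reduce.maxRecordsSizeForSpillThreshold", "2g")
  // SQL operators that buffer rows in ExternalAppendOnlyUnsafeRowArray.
  .config("spark.sql.windowExec.buffer.spill.size.threshold", "1g")
  .config("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold", "1g")
  .config("spark.sql.cartesianProductExec.buffer.spill.size.threshold", "1g")
  .getOrCreate()
```

Because the SQL-side thresholds default to the corresponding core config defaults, leaving them unset preserves the existing behaviour of never forcing a spill on size.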