LogKeys.scala
@@ -487,6 +487,7 @@ private[spark] object LogKeys {
case object NUM_DRIVERS extends LogKey
case object NUM_DROPPED_PARTITIONS extends LogKey
case object NUM_EFFECTIVE_RULE_OF_RUNS extends LogKey
case object NUM_ELEMENTS_SPILL_RECORDS extends LogKey
case object NUM_ELEMENTS_SPILL_THRESHOLD extends LogKey
case object NUM_EVENTS extends LogKey
case object NUM_EXAMPLES extends LogKey
@@ -768,6 +769,8 @@ private[spark] object LogKeys {
case object SPARK_REPO_URL extends LogKey
case object SPARK_REVISION extends LogKey
case object SPARK_VERSION extends LogKey
case object SPILL_RECORDS_SIZE extends LogKey
case object SPILL_RECORDS_SIZE_THRESHOLD extends LogKey
case object SPILL_TIMES extends LogKey
case object SQL_TEXT extends LogKey
case object SRC_PATH extends LogKey

ShuffleExternalSorter.java
@@ -89,6 +89,11 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport {
*/
private final int numElementsForSpillThreshold;

/**
* Force this sorter to spill when the size in memory is beyond this threshold.
*/
private final long recordsSizeForSpillThreshold;

/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;

@@ -112,6 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport {
@Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
private long pageCursor = -1;
private long inMemRecordsSize = 0;

// Checksum calculator for each partition. Empty when shuffle checksum disabled.
private final Checksum[] partitionChecksums;
@@ -136,6 +142,8 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport {
(int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.numElementsForSpillThreshold =
(int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
this.recordsSizeForSpillThreshold =
(long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
this.writeMetrics = writeMetrics;
this.inMemSorter = new ShuffleInMemorySorter(
this, initialSize, (boolean) conf.get(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT()));
@@ -338,6 +346,7 @@ private long freeMemory() {
allocatedPages.clear();
currentPage = null;
pageCursor = 0;
inMemRecordsSize = 0;
return memoryFreed;
}

@@ -417,12 +426,17 @@ private void acquireNewPageIfNecessary(int required) {
public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
throws IOException {

// for tests
assert(inMemSorter != null);
if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
logger.info("Spilling data because number of spilledRecords crossed the threshold {}" +
logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold {}",
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS$.MODULE$, inMemSorter.numRecords()),
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD$.MODULE$, numElementsForSpillThreshold));
spill();
} else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
MDC.of(LogKeys.SPILL_RECORDS_SIZE$.MODULE$, inMemRecordsSize),
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD$.MODULE$, recordsSizeForSpillThreshold));
spill();
}

growPointerArrayIfNecessary();
@@ -439,6 +453,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
inMemSorter.insertRecord(recordAddress, partitionId);
inMemRecordsSize += required;
}

/**
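Taken together, the `ShuffleExternalSorter` changes above track the byte size of buffered records (`inMemRecordsSize`) and force a spill when either the element count or the byte size crosses its threshold. A minimal standalone sketch of that decision, with illustrative names (this is not Spark's actual API):

```scala
// Illustrative sketch of the dual force-spill trigger; these names are made up.
object SpillDecision {
  // True when either limit is crossed: record count (the existing check) or
  // total in-memory record size in bytes (the check added by this PR).
  def shouldForceSpill(
      numRecords: Int,
      numElementsThreshold: Int,
      inMemRecordsSize: Long,
      recordsSizeThreshold: Long): Boolean = {
    numRecords >= numElementsThreshold || inMemRecordsSize >= recordsSizeThreshold
  }
}
```

With the default size threshold of `Long.MaxValue`, the size clause can never fire first, so behavior is unchanged unless the new config is set.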

UnsafeExternalSorter.java
@@ -80,6 +80,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
*/
private final int numElementsForSpillThreshold;

/**
* Force this sorter to spill when the size in memory is beyond this threshold.
*/
private final long recordsSizeForSpillThreshold;

/**
* Memory pages that hold the records being sorted. The pages in this list are freed when
* spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -92,6 +97,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {

// These variables are reset after spilling:
@Nullable private volatile UnsafeInMemorySorter inMemSorter;
private long inMemRecordsSize = 0;

private MemoryBlock currentPage = null;
private long pageCursor = -1;
@@ -110,11 +116,13 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
int initialSize,
long pageSizeBytes,
int numElementsForSpillThreshold,
long recordsSizeForSpillThreshold,
UnsafeInMemorySorter inMemorySorter,
long existingMemoryConsumption) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */);
pageSizeBytes, numElementsForSpillThreshold, recordsSizeForSpillThreshold,
inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
sorter.totalSpillBytes += existingMemoryConsumption;
@@ -133,10 +141,11 @@ public static UnsafeExternalSorter create(
int initialSize,
long pageSizeBytes,
int numElementsForSpillThreshold,
long recordsSizeForSpillThreshold,
boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes,
numElementsForSpillThreshold, null, canUseRadixSort);
numElementsForSpillThreshold, recordsSizeForSpillThreshold, null, canUseRadixSort);
}

private UnsafeExternalSorter(
@@ -149,6 +158,7 @@ private UnsafeExternalSorter(
int initialSize,
long pageSizeBytes,
int numElementsForSpillThreshold,
long recordsSizeForSpillThreshold,
@Nullable UnsafeInMemorySorter existingInMemorySorter,
boolean canUseRadixSort) {
super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
@@ -178,6 +188,7 @@ private UnsafeExternalSorter(
this.inMemSorter = existingInMemorySorter;
}
this.peakMemoryUsedBytes = getMemoryUsage();
this.recordsSizeForSpillThreshold = recordsSizeForSpillThreshold;
this.numElementsForSpillThreshold = numElementsForSpillThreshold;

// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
@@ -238,6 +249,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
inMemSorter.freeMemory();
inMemRecordsSize = 0;
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
// pages, we might not be able to get memory for the pointer array.
@@ -480,9 +492,15 @@ public void insertRecord(

assert(inMemSorter != null);
if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
logger.info("Spilling data because number of spilledRecords crossed the threshold {}",
logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold {}",
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS$.MODULE$, inMemSorter.numRecords()),
MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD$.MODULE$, numElementsForSpillThreshold));
spill();
} else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
MDC.of(LogKeys.SPILL_RECORDS_SIZE$.MODULE$, inMemRecordsSize),
MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD$.MODULE$, recordsSizeForSpillThreshold));
spill();
}

final int uaoSize = UnsafeAlignedOffset.getUaoSize();
@@ -497,6 +515,7 @@ public void insertRecord(
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
inMemRecordsSize += required;
}

/**
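Both `create` and `createWithExistingInMemorySorter` now take the size threshold immediately after the element threshold. A sketch of an updated call site, assuming the usual dependencies are already in scope, as they are in any existing caller:

```scala
// Sketch of an updated caller; every argument here is assumed to be in scope.
val sorter = UnsafeExternalSorter.create(
  taskMemoryManager,
  blockManager,
  serializerManager,
  taskContext,
  recordComparatorSupplier,
  prefixComparator,
  /* initialSize */ 1024,
  pageSizeBytes,
  numElementsForSpillThreshold,
  recordsSizeForSpillThreshold, // new: force-spill once buffered bytes exceed this
  canUseRadixSort)
```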

core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1596,6 +1596,18 @@ package object config {
.intConf
.createWithDefault(Integer.MAX_VALUE)

private[spark] val SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD =
ConfigBuilder("spark.shuffle.spill.maxRecordsSizeForSpillThreshold")
.internal()
.doc("The maximum size in memory before forcing the shuffle sorter to spill. " +
"By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " +
"until we reach some limitations, like the max page size limitation for the pointer " +
"array in the sorter.")
.version("4.1.0")
.bytesConf(ByteUnit.BYTE)
.checkValue(v => v > 0, "The threshold should be positive.")
.createWithDefault(Long.MaxValue)

private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
.internal()
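Since the new threshold is a bytes config, it accepts the usual size suffixes. A sketch of enabling it, here capping buffered shuffle records at 2 GiB (the value is illustrative; the right setting is workload-dependent):

```scala
import org.apache.spark.SparkConf

// Parsed as bytes, so suffixes like "512m" and "2g" are accepted.
val conf = new SparkConf()
  .set("spark.shuffle.spill.maxRecordsSizeForSpillThreshold", "2g")
```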

Spillable.scala
@@ -58,6 +58,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
private[this] val numElementsForceSpillThreshold: Int =
SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD)

// Force this collection to spill when its size is greater than this threshold
private[this] val maxSizeForceSpillThreshold: Long =
SparkEnv.get.conf.get(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)

// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
@volatile private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -80,21 +84,25 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
* @return true if `collection` was spilled to disk; false otherwise
*/
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
val shouldSpill = if (_elementsRead > numElementsForceSpillThreshold
|| currentMemory > maxSizeForceSpillThreshold) {

> **Contributor:** By moving `_elementsRead > numElementsForceSpillThreshold` here, we would actually reduce some unnecessary allocations. Nice!

// Check number of elements or memory usage limits, whichever is hit first
true
} else if (_elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
shouldSpill = currentMemory >= myMemoryThreshold
currentMemory >= myMemoryThreshold
} else {
false
}
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
logSpillage(currentMemory, _elementsRead)
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
@@ -140,12 +148,14 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
* Prints a standard log message detailing spillage.
*
* @param size number of bytes spilled
* @param elements number of elements read from input since last spill
*/
@inline private def logSpillage(size: Long): Unit = {
@inline private def logSpillage(size: Long, elements: Int): Unit = {
val threadId = Thread.currentThread().getId
logInfo(log"Thread ${MDC(LogKeys.THREAD_ID, threadId)} " +
log"spilling in-memory map of ${MDC(LogKeys.BYTE_SIZE,
org.apache.spark.util.Utils.bytesToString(size))} to disk " +
org.apache.spark.util.Utils.bytesToString(size))} " +
log"(elements: ${MDC(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, elements)}) to disk " +
log"(${MDC(LogKeys.NUM_SPILLS, _spillCount)} times so far)")
}
}
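
The reworked `maybeSpill` now evaluates the two hard limits before the soft, sampled memory check. A simplified restatement of the new decision order in pure-function form (the real method mutates `myMemoryThreshold` and spills as a side effect):

```scala
// Simplified restatement of the new maybeSpill decision order.
def shouldSpill(
    elementsRead: Int,
    currentMemory: Long,
    myMemoryThreshold: Long,
    numElementsForceSpillThreshold: Int,
    maxSizeForceSpillThreshold: Long,
    acquireMemory: Long => Long): Boolean = {
  if (elementsRead > numElementsForceSpillThreshold ||
      currentMemory > maxSizeForceSpillThreshold) {
    true // hard limits win; no memory is requested at all
  } else if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    // Soft limit: try to double our grant; spill only if that fails.
    val granted = acquireMemory(2 * currentMemory - myMemoryThreshold)
    currentMemory >= myMemoryThreshold + granted
  } else {
    false
  }
}
```

Evaluating the hard limits first is the allocation win the reviewer points out: once a force-spill is due, the `acquireMemory` request is skipped entirely.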

UnsafeExternalSorterSuite.java
@@ -87,9 +87,13 @@ public int compare(
private final long pageSizeBytes = conf.getSizeAsBytes(
package$.MODULE$.BUFFER_PAGESIZE().key(), "4m");

private final int spillThreshold =
private final int spillElementsThreshold =
(int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());

private final long spillSizeThreshold =
(long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());


@BeforeEach
public void setUp() throws Exception {
MockitoAnnotations.openMocks(this).close();
@@ -163,7 +167,8 @@ private UnsafeExternalSorter newSorter() throws IOException {
prefixComparator,
/* initialSize */ 1024,
pageSizeBytes,
spillThreshold,
spillElementsThreshold,
spillSizeThreshold,
shouldUseRadixSort());
}

@@ -453,7 +458,8 @@ public void forcedSpillingWithoutComparator() throws Exception {
null,
/* initialSize */ 1024,
pageSizeBytes,
spillThreshold,
spillElementsThreshold,
spillSizeThreshold,
shouldUseRadixSort());
long[] record = new long[100];
int recordSize = record.length * 8;
@@ -515,7 +521,8 @@ public void testPeakMemoryUsed() throws Exception {
prefixComparator,
1024,
pageSizeBytes,
spillThreshold,
spillElementsThreshold,
spillSizeThreshold,
shouldUseRadixSort());

// Peak memory should be monotonically increasing. More specifically, every time
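The suite reads both thresholds from the SparkConf and threads them through every sorter construction. A test that wants spills to trigger early can shrink both knobs; a sketch with illustrative values (the keys are the ones defined by this PR and its count-based predecessor):

```scala
import org.apache.spark.SparkConf

// Illustrative test setup: spill after 100 records or 1 MiB of buffered data,
// whichever is crossed first.
val testConf = new SparkConf()
  .set("spark.shuffle.spill.numElementsForceSpillThreshold", "100")
  .set("spark.shuffle.spill.maxRecordsSizeForSpillThreshold", "1m")
```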

SQLConf.scala
@@ -3356,6 +3356,13 @@ object SQLConf {
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
buildConf("spark.sql.windowExec.buffer.spill.size.threshold")

> **Contributor:** The config name is a bit confusing: `spark.sql.windowExec.buffer.spill.threshold` vs `spark.sql.windowExec.buffer.spill.size.threshold`. Same for the others introduced. I will let @HyukjinKwon or @cloud-fan comment better though.

> **Member:** I am not super used to this area. I would rather follow the suggestions from you / others.

> **Contributor:** Thanks @HyukjinKwon! +CC @dongjoon-hyun as well.

.internal()
.doc("Threshold for size of rows to be spilled by window operator")
.version("4.1.0")
.fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)

val WINDOW_GROUP_LIMIT_THRESHOLD =
buildConf("spark.sql.optimizer.windowGroupLimitThreshold")
.internal()
@@ -3377,6 +3384,15 @@ object SQLConf {
.intConf
.createWithDefault(4096)

val SESSION_WINDOW_BUFFER_SPILL_SIZE_THRESHOLD =
buildConf("spark.sql.sessionWindow.buffer.spill.size.threshold")
.internal()
.doc("Threshold for size of rows to be spilled by window operator. Note that " +
"the buffer is used only for the query Spark cannot apply aggregations on determining " +
"session window.")
.version("4.1.0")
.fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)

val SESSION_WINDOW_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.sessionWindow.buffer.spill.threshold")
.internal()
Expand Down Expand Up @@ -3420,6 +3436,13 @@ object SQLConf {
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold")
.internal()
.doc("Threshold for size of rows to be spilled by sort merge join operator")
.version("4.1.0")
.fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)

val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold")
.internal()
@@ -3437,6 +3460,13 @@ object SQLConf {
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD =
buildConf("spark.sql.cartesianProductExec.buffer.spill.size.threshold")
.internal()
.doc("Threshold for size of rows to be spilled by cartesian product operator")
.version("4.1.0")
.fallbackConf(SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD)

val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames")
.doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" +
" as regular expressions.")
@@ -6679,24 +6709,35 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)

def windowExecBufferSpillSizeThreshold: Long = getConf(WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)

def windowGroupLimitThreshold: Int = getConf(WINDOW_GROUP_LIMIT_THRESHOLD)

def sessionWindowBufferInMemoryThreshold: Int = getConf(SESSION_WINDOW_BUFFER_IN_MEMORY_THRESHOLD)

def sessionWindowBufferSpillThreshold: Int = getConf(SESSION_WINDOW_BUFFER_SPILL_THRESHOLD)

def sessionWindowBufferSpillSizeThreshold: Long =
getConf(SESSION_WINDOW_BUFFER_SPILL_SIZE_THRESHOLD)

def sortMergeJoinExecBufferInMemoryThreshold: Int =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)

def sortMergeJoinExecBufferSpillThreshold: Int =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)

def sortMergeJoinExecBufferSpillSizeThreshold: Long =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)

def cartesianProductExecBufferInMemoryThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)

def cartesianProductExecBufferSpillThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)

def cartesianProductExecBufferSizeSpillThreshold: Long =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD)

def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)

def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
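
Each SQL-side size threshold is a `fallbackConf`, so leaving it unset defers to the core `spark.shuffle.spill.maxRecordsSizeForSpillThreshold` (default `Long.MaxValue`, i.e. effectively disabled). A sketch of overriding them per operator on a running session (`spark` is assumed to be a `SparkSession`; the values are illustrative):

```scala
// Per-operator byte-size spill thresholds; any left unset fall back to
// spark.shuffle.spill.maxRecordsSizeForSpillThreshold.
spark.conf.set("spark.sql.windowExec.buffer.spill.size.threshold", "256m")
spark.conf.set("spark.sql.sessionWindow.buffer.spill.size.threshold", "256m")
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold", "512m")
spark.conf.set("spark.sql.cartesianProductExec.buffer.spill.size.threshold", "512m")
```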