Skip to content

Commit 82e21c1

Browse files
committed
Force spilling in UnsafeExternalSortSuite.
1 parent 88b72db commit 82e21c1

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444

4545
final class UnsafeExternalRowSorter {
4646

47+
/**
48+
* If positive, forces records to be spilled to disk at the given frequency (measured in numbers
49+
* of records). This is only intended to be used in tests.
50+
*/
51+
private int testSpillFrequency = 0;
52+
53+
private long numRowsInserted = 0;
54+
4755
private final StructType schema;
4856
private final UnsafeRowConverter rowConverter;
4957
private final PrefixComputer prefixComputer;
@@ -77,6 +85,15 @@ public UnsafeExternalRowSorter(
7785
);
7886
}
7987

88+
/**
89+
* Forces spills to occur every `frequency` records. Only for use in tests.
90+
*/
91+
@VisibleForTesting
92+
void setTestSpillFrequency(int frequency) {
93+
assert frequency > 0 : "Frequency must be positive";
94+
testSpillFrequency = frequency;
95+
}
96+
8097
@VisibleForTesting
8198
void insertRow(InternalRow row) throws IOException {
8299
final int sizeRequirement = rowConverter.getSizeRequirement(row);
@@ -99,6 +116,10 @@ void insertRow(InternalRow row) throws IOException {
99116
sizeRequirement,
100117
prefix
101118
);
119+
numRowsInserted++;
120+
if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
121+
spill();
122+
}
102123
}
103124

104125
@VisibleForTesting

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,15 @@ case class ExternalSort(
253253
*
254254
* @param global when true performs a global sort of all partitions by shuffling the data first
255255
* if necessary.
256+
* @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
257+
* spill every `frequency` records.
256258
*/
257259
@DeveloperApi
258260
case class UnsafeExternalSort(
259261
sortOrder: Seq[SortOrder],
260262
global: Boolean,
261-
child: SparkPlan)
263+
child: SparkPlan,
264+
testSpillFrequency: Int = 0)
262265
extends UnaryNode {
263266

264267
private[this] val schema: StructType = child.schema
@@ -278,7 +281,11 @@ case class UnsafeExternalSort(
278281
override def computePrefix(row: InternalRow): Long = prefixComputer(row)
279282
}
280283
}
281-
new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
284+
val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
285+
if (testSpillFrequency > 0) {
286+
sorter.setTestSpillFrequency(testSpillFrequency)
287+
}
288+
sorter.sort(iterator)
282289
}
283290
child.execute().mapPartitions(doSort, preservesPartitioning = true)
284291
}

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
5454
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
5555
StructType(StructField("a", dataType, nullable = true) :: Nil)
5656
)
57+
assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
5758
checkAnswer(
5859
inputDf,
59-
UnsafeExternalSort(sortOrder, global = false, _: SparkPlan),
60+
UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 100),
6061
Sort(sortOrder, global = false, _: SparkPlan)
6162
)
6263
}

0 commit comments

Comments
 (0)