
Commit 7f875f9

Commit failing test demonstrating bug in handling objects in spills
1 parent 41b8881 commit 7f875f9
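
The new test (added in UnsafeExternalSortSuite.scala below) sorts rows whose second column is a nested struct, a value that the UnsafeRowConverter apparently stores through the ObjectPool this commit threads into UnsafeExternalRowSorter rather than inline in the row's bytes, and it forces a spill between inserts. The commit title and the test's structure suggest that object columns belonging to spilled rows are not handled correctly when the sorted output is read back. A condensed, hedged sketch of the sequence the test drives; `schema` and `sorter` are assumed to be constructed exactly as in the "sorting with object columns" test, so this is an illustration rather than standalone code:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}

// Assumes `schema` (string column + struct column) and `sorter` built as in the new test.
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)

// Row("Hello", Row(1)) carries a struct column, i.e. an object-backed field.
sorter.insertRow(toCatalyst(Row("Hello", Row(1))).asInstanceOf[InternalRow])
sorter.spill()   // force the object-bearing row out to a spill file
sorter.insertRow(toCatalyst(Row("World", Row(2))).asInstanceOf[InternalRow])

// Reading the sorted output back is where the spilled object columns are
// expected to be mishandled, which is what the failing test demonstrates.
val sorted = sorter.sort()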

File tree: 5 files changed (+116, -59 lines)

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ private void initializeForWriting() throws IOException {
    * Sort and spill the current records in response to memory pressure.
    */
   @VisibleForTesting
-  void spill() throws IOException {
+  public void spill() throws IOException {
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
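
Widening spill() from package-private to public (it already carried @VisibleForTesting) lets callers outside this package, such as UnsafeExternalRowSorter and the new SQL test, trigger a spill at a chosen point instead of waiting for memory pressure. A minimal sketch of that usage, assuming a sorter built with the mocked memory and block managers from UnsafeExternalSorterSuite; insertNumber mirrors the helper shown in that suite's diff below, and the test-body method name is illustrative only:

import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

// Inserts a 4-byte record whose payload and sort prefix are both `value`,
// mirroring UnsafeExternalSorterSuite.insertNumber.
def insertNumber(sorter: UnsafeExternalSorter, value: Int): Unit = {
  val arr = Array(value)
  sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value)
}

// Hypothetical test body; `sorter` is assumed to be constructed as in the suite.
def exerciseForcedSpill(sorter: UnsafeExternalSorter): Unit = {
  insertNumber(sorter, 3)
  insertNumber(sorter, 1)
  sorter.spill()                          // now callable from tests: sort and spill in-memory records to disk
  insertNumber(sorter, 2)
  val iter = sorter.getSortedIterator()   // merges the spilled run with the remaining in-memory records
  // ...assert the iterator yields the records in prefix order (1, 2, 3)...
}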

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 1 addition & 7 deletions
@@ -38,7 +38,6 @@
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
-import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -56,8 +55,6 @@ public class UnsafeExternalSorterSuite {
 
   final TaskMemoryManager memoryManager =
     new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-  // Compute key prefixes based on the records' partition ids
-  final HashPartitioner hashPartitioner = new HashPartitioner(4);
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
     @Override
@@ -138,11 +135,8 @@ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws
     sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
   }
 
-  /**
-   * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
-   */
   @Test
-  public void testSortingOnlyByPartitionId() throws Exception {
+  public void testSortingOnlyByPrefix() throws Exception {
 
     final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
       memoryManager,

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 64 additions & 48 deletions
@@ -24,12 +24,15 @@
 import scala.collection.Iterator;
 import scala.math.Ordering;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
@@ -41,61 +44,70 @@ final class UnsafeExternalRowSorter {
 
   private final StructType schema;
   private final UnsafeRowConverter rowConverter;
-  private final RowComparator rowComparator;
-  private final PrefixComparator prefixComparator;
   private final Function1<InternalRow, Long> prefixComputer;
+  private final ObjectPool objPool = new ObjectPool(128);
+  private final UnsafeExternalSorter sorter;
+  private byte[] rowConversionBuffer = new byte[1024 * 8];
 
   public UnsafeExternalRowSorter(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       // TODO: if possible, avoid this boxing of the return value
-      Function1<InternalRow, Long> prefixComputer) {
+      Function1<InternalRow, Long> prefixComputer) throws IOException {
     this.schema = schema;
     this.rowConverter = new UnsafeRowConverter(schema);
-    this.rowComparator = new RowComparator(ordering, schema);
-    this.prefixComparator = prefixComparator;
     this.prefixComputer = prefixComputer;
-  }
-
-  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
-    byte[] rowConversionBuffer = new byte[1024 * 8];
-    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+    sorter = new UnsafeExternalSorter(
       taskContext.taskMemoryManager(),
       sparkEnv.shuffleMemoryManager(),
       sparkEnv.blockManager(),
       taskContext,
-      rowComparator,
+      new RowComparator(ordering, schema.length(), objPool),
       prefixComparator,
       4096,
       sparkEnv.conf()
     );
+  }
+
+  @VisibleForTesting
+  void insertRow(InternalRow row) throws IOException {
+    final int sizeRequirement = rowConverter.getSizeRequirement(row);
+    if (sizeRequirement > rowConversionBuffer.length) {
+      rowConversionBuffer = new byte[sizeRequirement];
+    } else {
+      // Zero out the buffer that's used to hold the current row. This is necessary in order
+      // to ensure that rows hash properly, since garbage data from the previous row could
+      // otherwise end up as padding in this row. As a performance optimization, we only zero
+      // out the portion of the buffer that we'll actually write to.
+      Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
+    }
+    final int bytesWritten = rowConverter.writeRow(
+      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool);
+    assert (bytesWritten == sizeRequirement);
+    final long prefix = prefixComputer.apply(row);
+    sorter.insertRecord(
+      rowConversionBuffer,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      sizeRequirement,
+      prefix
+    );
+  }
+
+  @VisibleForTesting
+  void spill() throws IOException {
+    sorter.spill();
+  }
+
+  private void cleanupResources() {
+    sorter.freeMemory();
+  }
+
+  @VisibleForTesting
+  Iterator<InternalRow> sort() throws IOException {
     try {
-      while (inputIterator.hasNext()) {
-        final InternalRow row = inputIterator.next();
-        final int sizeRequirement = rowConverter.getSizeRequirement(row);
-        if (sizeRequirement > rowConversionBuffer.length) {
-          rowConversionBuffer = new byte[sizeRequirement];
-        } else {
-          // Zero out the buffer that's used to hold the current row. This is necessary in order
-          // to ensure that rows hash properly, since garbage data from the previous row could
-          // otherwise end up as padding in this row. As a performance optimization, we only zero
-          // out the portion of the buffer that we'll actually write to.
-          Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
-        }
-        final int bytesWritten =
-          rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET);
-        assert (bytesWritten == sizeRequirement);
-        final long prefix = prefixComputer.apply(row);
-        sorter.insertRecord(
-          rowConversionBuffer,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          sizeRequirement,
-          prefix
-        );
-      }
       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
       return new AbstractScalaRowIterator() {
 
@@ -113,7 +125,7 @@ public InternalRow next() {
             sortedIterator.loadNext();
             if (hasNext()) {
               row.pointTo(
-                sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema);
+                sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool);
               return row;
             } else {
               final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()];
@@ -125,14 +137,12 @@ public InternalRow next() {
                 sortedIterator.getRecordLength()
               );
               row.backingArray = rowDataCopy;
-              row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema);
+              row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool);
               sorter.freeMemory();
               return row;
             }
           } catch (IOException e) {
-            // TODO: we need to ensure that files are cleaned properly after an exception,
-            // so we need better cleanup methods than freeMemory().
-            sorter.freeMemory();
+            cleanupResources();
             // Scala iterators don't declare any checked exceptions, so we need to use this hack
             // to re-throw the exception:
             PlatformDependent.throwException(e);
@@ -141,30 +151,36 @@ public InternalRow next() {
         };
       };
     } catch (IOException e) {
-      // TODO: we need to ensure that files are cleaned properly after an exception,
-      // so we need better cleanup methods than freeMemory().
-      sorter.freeMemory();
+      cleanupResources();
       throw e;
     }
   }
 
+
+  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
+    while (inputIterator.hasNext()) {
+      insertRow(inputIterator.next());
+    }
+    return sort();
+  }
+
   private static final class RowComparator extends RecordComparator {
-    private final StructType schema;
     private final Ordering<InternalRow> ordering;
     private final int numFields;
+    private final ObjectPool objPool;
     private final UnsafeRow row1 = new UnsafeRow();
     private final UnsafeRow row2 = new UnsafeRow();
 
-    public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
-      this.schema = schema;
-      this.numFields = schema.length();
+    public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
+      this.numFields = numFields;
       this.ordering = ordering;
+      this.objPool = objPool;
     }
 
     @Override
     public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-      row1.pointTo(baseObj1, baseOff1, numFields, schema);
-      row2.pointTo(baseObj2, baseOff2, numFields, schema);
+      row1.pointTo(baseObj1, baseOff1, numFields, objPool);
+      row2.pointTo(baseObj2, baseOff2, numFields, objPool);
       return ordering.compare(row1, row2);
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 1 addition & 2 deletions
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
-import org.apache.spark.{SparkEnv, HashPartitioner}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -275,7 +274,7 @@ case class UnsafeExternalSort(
       val prefixComparator = new PrefixComparator {
         override def compare(prefix1: Long, prefix2: Long): Int = 0
       }
-      // TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation.
+      // TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation.
       def prefixComputer(row: InternalRow): Long = 0
       new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
     }

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 49 additions & 1 deletion
@@ -17,9 +17,14 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, AttributeReference, SortOrder}
+import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.{Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 
@@ -54,4 +59,47 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
       ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
       input.sortBy(t => (t._2, t._1)))
   }
+
+  test("sorting with object columns") {
+    // TODO: larger input data
+    val input = Seq(
+      Row("Hello", Row(1)),
+      Row("World", Row(2))
+    )
+
+    val schema = StructType(
+      StructField("a", StringType, nullable = false) ::
+      StructField("b", StructType(StructField("b", IntegerType, nullable = false) :: Nil)) ::
+      Nil
+    )
+
+    // Hack so that we don't need to pass in / mock TaskContext, SparkEnv, etc. Ultimately it would
+    // be better to not use this hack, but due to time constraints I have deferred this for
+    // followup PRs.
+    val sortResult = TestSQLContext.sparkContext.parallelize(input, 1).mapPartitions { iter =>
+      val rows = iter.toSeq
+      val sortOrder = SortOrder(BoundReference(0, StringType, nullable = false), Ascending)
+
+      val sorter = new UnsafeExternalRowSorter(
+        schema,
+        GenerateOrdering.generate(Seq(sortOrder), schema.toAttributes),
+        new PrefixComparator {
+          override def compare(prefix1: Long, prefix2: Long): Int = 0
+        },
+        x => 0L
+      )
+
+      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
+
+      sorter.insertRow(toCatalyst(input.head).asInstanceOf[InternalRow])
+      sorter.spill()
+      input.tail.foreach { row =>
+        sorter.insertRow(toCatalyst(row).asInstanceOf[InternalRow])
+      }
+      val sortedRowsIterator = sorter.sort()
+      sortedRowsIterator.map(CatalystTypeConverters.convertToScala(_, schema).asInstanceOf[Row])
+    }.collect()
+
+    assert(input.sortBy(t => t.getString(0)) === sortResult)
+  }
 }
