Commit 4c37ba6

Add tests for sorting on all primitive types.
1 parent 6890863 commit 4c37ba6

5 files changed: +171 −125 lines changed


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java

Lines changed: 1 addition & 6 deletions

@@ -37,8 +37,7 @@ final class UnsafeSorterSpillWriter {
 
   // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
   // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
-  // data through a byte array. This array does not need to be large enough to hold a single
-  // record;
+  // data through a byte array.
   private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
 
   private final File file;
@@ -115,10 +114,6 @@ public void close() throws IOException {
     writeBuffer = null;
   }
 
-  public long numberOfSpilledBytes() {
-    return file.length();
-  }
-
   public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
     return new UnsafeSorterSpillReader(blockManager, file, blockId);
   }
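The comment trimmed above describes the spill path: records are copied from managed memory to the DiskBlockObjectWriter in chunks of at most DISK_WRITE_BUFFER_SIZE bytes, so the write buffer never needs to hold a whole record. A minimal Scala sketch of that chunked-copy pattern, with hypothetical callbacks standing in for the real memory-copy and disk-writer calls (an illustration, not the actual UnsafeSorterSpillWriter code):

// Sketch only: copy one record to disk through a small fixed-size buffer.
// copyFromMemory and writeToDisk are hypothetical stand-ins.
def spillRecord(
    copyFromMemory: (Long, Array[Byte], Int) => Unit, // (srcOffset, dest, numBytes)
    writeToDisk: (Array[Byte], Int) => Unit,          // (buffer, numBytes)
    recordLength: Int,
    bufferSize: Int = 1024): Unit = {
  val buffer = new Array[Byte](bufferSize)
  var copied = 0
  while (copied < recordLength) {
    val toTransfer = math.min(bufferSize, recordLength - copied)
    copyFromMemory(copied.toLong, buffer, toTransfer) // fill the buffer with the next chunk
    writeToDisk(buffer, toTransfer)                   // hand the chunk to the disk writer
    copied += toTransfer
  }
}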

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 7 additions & 5 deletions

@@ -20,7 +20,6 @@
 import java.io.IOException;
 import java.util.Arrays;
 
-import scala.Function1;
 import scala.collection.Iterator;
 import scala.math.Ordering;
 
@@ -47,17 +46,20 @@ final class UnsafeExternalRowSorter {
 
   private final StructType schema;
   private final UnsafeRowConverter rowConverter;
-  private final Function1<InternalRow, Long> prefixComputer;
+  private final PrefixComputer prefixComputer;
   private final ObjectPool objPool = new ObjectPool(128);
   private final UnsafeExternalSorter sorter;
   private byte[] rowConversionBuffer = new byte[1024 * 8];
 
+  public static abstract class PrefixComputer {
+    abstract long computePrefix(InternalRow row);
+  }
+
   public UnsafeExternalRowSorter(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
-      // TODO: if possible, avoid this boxing of the return value
-      Function1<InternalRow, Long> prefixComputer) throws IOException {
+      PrefixComputer prefixComputer) throws IOException {
     this.schema = schema;
     this.rowConverter = new UnsafeRowConverter(schema);
     this.prefixComputer = prefixComputer;
@@ -90,7 +92,7 @@ void insertRow(InternalRow row) throws IOException {
     final int bytesWritten = rowConverter.writeRow(
       row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool);
     assert (bytesWritten == sizeRequirement);
-    final long prefix = prefixComputer.apply(row);
+    final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
       rowConversionBuffer,
       PlatformDependent.BYTE_ARRAY_OFFSET,
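The PrefixComputer class added above replaces the old Function1<InternalRow, Long>: computePrefix returns a primitive long, so the per-row prefix is no longer boxed into a java.lang.Long (the removed TODO). A hedged Scala sketch of how a caller can implement it, assuming the first sort column is a non-null long (illustrative only; the real computers come from SortPrefixUtils, used in the next file, and the import path is assumed from the surrounding code):

// Illustrative PrefixComputer: use the raw long value of column 0 as the sort
// prefix. Assumes a non-null LongType column; not the actual Spark code.
import org.apache.spark.sql.catalyst.InternalRow

val longColumnPrefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
  override def computePrefix(row: InternalRow): Long = row.getLong(0)
}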

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 6 additions & 5 deletions

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
 import org.apache.spark.util.{CompletionIterator, MutablePair}
 import org.apache.spark.{HashPartitioner, SparkEnv}
 
@@ -271,11 +270,13 @@ case class UnsafeExternalSort(
     assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
     def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val ordering = newOrdering(sortOrder, child.output)
-      val prefixComparator = new PrefixComparator {
-        override def compare(prefix1: Long, prefix2: Long): Int = 0
+      val prefixComparator = SortPrefixUtils.getPrefixComparator(sortOrder.head)
+      val prefixComputer = {
+        val prefixComputer = SortPrefixUtils.getPrefixComputer(sortOrder.head)
+        new UnsafeExternalRowSorter.PrefixComputer {
+          override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+        }
       }
-      // TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation.
-      def prefixComputer(row: InternalRow): Long = 0
       new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
     }
     child.execute().mapPartitions(doSort, preservesPartitioning = true)
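SortPrefixUtils.getPrefixComparator and getPrefixComputer map the leading SortOrder to a type-specific comparator and prefix function; their implementation is not part of the hunks shown here. As a hedged sketch only, a prefix pair for an ascending sort on a non-null LongType column could look like this (signatures inferred from the surrounding code, not copied from SortPrefixUtils):

// Hedged sketch, not the actual SortPrefixUtils code: for an ascending sort on
// a non-null LongType column, the prefix can be the column value itself, and
// the prefix comparator is plain signed-long comparison.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

def longPrefixComputer(ordinal: Int): InternalRow => Long =
  (row: InternalRow) => row.getLong(ordinal)

val longPrefixComparator = new PrefixComparator {
  override def compare(prefix1: Long, prefix2: Long): Int =
    java.lang.Long.compare(prefix1, prefix2)
}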

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 132 additions & 42 deletions

@@ -17,18 +17,16 @@
 
 package org.apache.spark.sql.execution
 
-import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
 import org.apache.spark.SparkFunSuite
-
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.BoundReference
 import org.apache.spark.sql.catalyst.util._
-
 import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame}
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
 
 /**
  * Base class for writing tests for individual physical operators. For an example of how this
@@ -77,13 +75,93 @@ class SparkPlanTest extends SparkFunSuite {
       case None =>
     }
   }
+
+  /**
+   * Runs the plan and makes sure the answer matches the result produced by a reference plan.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+   *                             instantiate a reference implementation of the physical operator
+   *                             that's being tested. The result of executing this plan will be
+   *                             treated as the source-of-truth for the test.
+   */
+  protected def checkAnswer(
+      input: DataFrame,
+      planFunction: SparkPlan => SparkPlan,
+      expectedPlanFunction: SparkPlan => SparkPlan): Unit = {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction) match {
+      case Some(errorMessage) => fail(errorMessage)
+      case None =>
+    }
+  }
 }
 
 /**
  * Helper methods for writing tests of individual physical operators.
  */
 object SparkPlanTest {
 
+  /**
+   * Runs the plan and makes sure the answer matches the result produced by a reference plan.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+   *                             instantiate a reference implementation of the physical operator
+   *                             that's being tested. The result of executing this plan will be
+   *                             treated as the source-of-truth for the test.
+   */
+  def checkAnswer(
+      input: DataFrame,
+      planFunction: SparkPlan => SparkPlan,
+      expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = {
+
+    val outputPlan = planFunction(input.queryExecution.sparkPlan)
+    val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
+
+    val expectedAnswer: Seq[Row] = try {
+      executePlan(input, expectedOutputPlan)
+    } catch {
+      case NonFatal(e) =>
+        val errorMessage =
+          s"""
+             | Exception thrown while executing Spark plan to calculate expected answer:
+             | $expectedOutputPlan
+             | == Exception ==
+             | $e
+             | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+          """.stripMargin
+        return Some(errorMessage)
+    }
+
+    val actualAnswer: Seq[Row] = try {
+      executePlan(input, outputPlan)
+    } catch {
+      case NonFatal(e) =>
+        val errorMessage =
+          s"""
+             | Exception thrown while executing Spark plan:
+             | $outputPlan
+             | == Exception ==
+             | $e
+             | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
          """.stripMargin
+        return Some(errorMessage)
+    }
+
+    compareAnswers(actualAnswer, expectedAnswer).map { errorMessage =>
+      s"""
+         | Results do not match.
+         | Actual result Spark plan:
+         | $outputPlan
+         | Expected result Spark plan:
+         | $expectedOutputPlan
+         | $errorMessage
      """.stripMargin
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * @param input the input data to be used.
@@ -98,22 +176,33 @@ object SparkPlanTest {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
 
-    // A very simple resolver to make writing tests easier. In contrast to the real resolver
-    // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = outputPlan transform {
-      case plan: SparkPlan =>
-        val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
-          case (a, i) =>
-            (a.name, BoundReference(i, a.dataType, a.nullable))
-        }.toMap
+    val sparkAnswer: Seq[Row] = try {
+      executePlan(input, outputPlan)
+    } catch {
+      case NonFatal(e) =>
+        val errorMessage =
+          s"""
+             | Exception thrown while executing Spark plan:
+             | $outputPlan
+             | == Exception ==
+             | $e
+             | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
          """.stripMargin
+        return Some(errorMessage)
+    }
 
-      plan.transformExpressions {
-        case UnresolvedAttribute(Seq(u)) =>
-          inputMap.getOrElse(u,
-            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
-      }
+    compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage =>
+      s"""
+         | Results do not match for Spark plan:
+         | $outputPlan
+         | $errorMessage
      """.stripMargin
     }
+  }
 
+  private def compareAnswers(
+      sparkAnswer: Seq[Row],
+      expectedAnswer: Seq[Row]): Option[String] = {
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -130,38 +219,39 @@
       }
       converted.sortBy(_.toString())
     }
-
-    val sparkAnswer: Seq[Row] = try {
-      resolvedPlan.executeCollect().toSeq
-    } catch {
-      case NonFatal(e) =>
-        val errorMessage =
-          s"""
-             | Exception thrown while executing Spark plan:
-             | $outputPlan
-             | == Exception ==
-             | $e
-             | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
-          """.stripMargin
-        return Some(errorMessage)
-    }
-
     if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
       val errorMessage =
         s"""
-           | Results do not match for Spark plan:
-           | $outputPlan
           | == Results ==
           | ${sideBySide(
-            s"== Correct Answer - ${expectedAnswer.size} ==" +:
+            s"== Expected Answer - ${expectedAnswer.size} ==" +:
              prepareAnswer(expectedAnswer).map(_.toString()),
-            s"== Spark Answer - ${sparkAnswer.size} ==" +:
+            s"== Actual Answer - ${sparkAnswer.size} ==" +:
              prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
         """.stripMargin
-      return Some(errorMessage)
+      Some(errorMessage)
+    } else {
+      None
     }
+  }
 
-    None
+  private def executePlan(input: DataFrame, outputPlan: SparkPlan): Seq[Row] = {
+    // A very simple resolver to make writing tests easier. In contrast to the real resolver
+    // this is always case sensitive and does not try to handle scoping or complex type resolution.
+    val resolvedPlan = outputPlan transform {
+      case plan: SparkPlan =>
+        val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
+          case (a, i) =>
+            (a.name, BoundReference(i, a.dataType, a.nullable))
+        }.toMap
+
+        plan.transformExpressions {
+          case UnresolvedAttribute(Seq(u)) =>
+            inputMap.getOrElse(u,
+              sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+        }
+    }
+    resolvedPlan.executeCollect().toSeq
   }
 }

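The new reference-plan overload of checkAnswer is what the commit's primitive-type sort tests build on: instead of a hand-written expected answer, the operator under test is compared against a trusted operator run on the same input. A hypothetical usage sketch follows; the data, column names, the Seq-to-DataFrame conversion, and the exact constructor arguments of UnsafeExternalSort and Sort are assumptions, since the commit's actual test file is not shown in this excerpt.

// Hypothetical test built on the new checkAnswer(input, planFunction,
// expectedPlanFunction) overload. Constructor arguments and test data are
// illustrative assumptions, not the commit's actual test code. toDF is
// assumed to be available through SparkPlanTest's DataFrameHolder implicits.
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}

class ExampleUnsafeSortSuite extends SparkPlanTest {
  test("unsafe external sort matches the reference Sort operator") {
    val input = Seq((3, "c"), (1, "a"), (2, "b")).toDF("key", "value")
    val sortOrder = Seq(SortOrder(UnresolvedAttribute("key"), Ascending))
    checkAnswer(
      input,
      (child: SparkPlan) => UnsafeExternalSort(sortOrder, global = true, child),
      (child: SparkPlan) => Sort(sortOrder, global = true, child))
  }
}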