
Commit 3c040b6

Keep sort order of rows after external sorter when writing.

1 parent: 40a4cfc
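
This commit (SPARK-19352) threads the executed plan's `outputOrdering` through `FileFormatWriter` into `UnsafeKVExternalSorter`, so the sort the writer performs before emitting files (by partition columns, then bucket id) also preserves any per-partition ordering already present in the data instead of discarding it. The user-visible scenario, mirroring the new regression test in `FileSourceStrategySuite` below (a sketch assuming an active `SparkSession` named `spark` and a scratch directory `tempDir`):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Rows in each `id` partition are sorted by `value` descending before the write.
val df = spark.range(100)
  .select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
  .repartition($"id")
  .sortWithinPartitions($"value".desc)

// With this change, the writer's internal re-sort keeps the value-descending
// order within each id=... directory rather than re-sorting by `id` alone.
df.write.partitionBy("id").parquet(tempDir)
```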

4 files changed: 172 additions, 24 deletions

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 29 additions & 4 deletions

@@ -19,13 +19,18 @@
 
 import javax.annotation.Nullable;
 import java.io.IOException;
+import java.util.List;
+
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -58,7 +63,7 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold) throws IOException {
     this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
-      numElementsForSpillThreshold, null);
+      numElementsForSpillThreshold, null, null);
   }
 
   public UnsafeKVExternalSorter(
@@ -69,14 +74,34 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold,
       @Nullable BytesToBytesMap map) throws IOException {
+    this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
+      numElementsForSpillThreshold, map, null);
+  }
+
+  public UnsafeKVExternalSorter(
+      StructType keySchema,
+      StructType valueSchema,
+      BlockManager blockManager,
+      SerializerManager serializerManager,
+      long pageSizeBytes,
+      long numElementsForSpillThreshold,
+      @Nullable BytesToBytesMap map,
+      @Nullable List<SortOrder> ordering) throws IOException {
     this.keySchema = keySchema;
    this.valueSchema = valueSchema;
     final TaskContext taskContext = TaskContext.get();
 
     prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
     PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
-    BaseOrdering ordering = GenerateOrdering.create(keySchema);
-    KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+    KVComparator recordComparator = null;
+    if (ordering == null) {
+      recordComparator = new KVComparator(GenerateOrdering.create(keySchema), keySchema.length());
+    } else {
+      Seq<SortOrder> orderingSeq =
+        JavaConverters.collectionAsScalaIterableConverter(ordering).asScala().toSeq();
+      recordComparator = new KVComparator((BaseOrdering)GenerateOrdering.generate(orderingSeq),
+        ordering.size());
+    }
     boolean canUseRadixSort = keySchema.length() == 1 &&
       SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
@@ -137,7 +162,7 @@ public UnsafeKVExternalSorter(
         blockManager,
         serializerManager,
         taskContext,
-        new KVComparator(ordering, keySchema.length()),
+        recordComparator,
         prefixComparator,
         SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
           UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
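
The new overload accepts an explicit `List<SortOrder>`: when it is non-null, the record comparator is generated from those `SortOrder` expressions via `GenerateOrdering.generate`, instead of the all-ascending comparator that `GenerateOrdering.create` derives from the key schema. A minimal sketch of building that argument from Scala (mirroring the new test suite; the ordinals are positions in the key schema, assumed here to be two binary columns):

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, SortOrder}
import org.apache.spark.sql.types.BinaryType

// Ascending on key column 0, descending on key column 1. Passing null instead
// keeps the old behavior: an all-ascending comparison over the key schema.
val ordering: java.util.List[SortOrder] =
  (SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
    SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil).asJava
```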

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 65 additions & 16 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
@@ -68,7 +69,8 @@ object FileFormatWriter extends Logging {
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val orderingInPartition: Seq[SortOrder])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -125,7 +127,8 @@ object FileFormatWriter extends Logging {
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      orderingInPartition = queryExecution.executedPlan.outputOrdering
     )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -368,17 +371,58 @@
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
+      // If there is sort ordering in the data, we need to keep the ordering.
+      val orderingExpressions: Seq[Expression] = if (description.orderingInPartition.isEmpty) {
+        Nil
+      } else {
+        description.orderingInPartition.map(_.child)
+      }
+
+      // We should first sort by partition columns, then bucket id, then sort ordering in the data,
+      // and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
+        description.partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
       val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
+      val bucketIdExprIndex =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length - 1
+
+      val sortingKeySchema = StructType(sortingExpressions.zipWithIndex.map { case (e, index) =>
+        e match {
+          case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+          // The sorting expressions are all `Attribute` except bucket id and
+          // sorting order's children expressions.
+          case _ if index == bucketIdExprIndex =>
+            StructField("bucketId", IntegerType, nullable = false)
+          case _ if index > bucketIdExprIndex =>
+            StructField(s"_sortOrder_$index", e.dataType, e.nullable)
+        }
       })
 
+      val beginSortingExpr =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length
+      val recordSortingOrder =
+        if (description.orderingInPartition.isEmpty) {
+          null
+        } else {
+          sortingExpressions.zipWithIndex.map { case (field, ordinal) =>
+            if (ordinal < beginSortingExpr ||
+                ordinal > beginSortingExpr + orderingExpressions.length) {
+              // For partition column, bucket id and sort by columns, we sort by ascending.
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
+            } else {
+              // For the sort ordering of data, we need to keep its sort direction and
+              // null ordering.
+              val direction =
+                description.orderingInPartition(ordinal - beginSortingExpr).direction
+              val nullOrdering =
+                description.orderingInPartition(ordinal - beginSortingExpr).nullOrdering
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true),
                direction, nullOrdering)
+            }
+          }.asJava
+        }
+
       // Returns the data columns to be written given an input row
       val getOutputRow = UnsafeProjection.create(
         description.dataColumns, description.allColumns)
@@ -395,20 +439,25 @@
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
         SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
+          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+        null,
+        recordSortingOrder)
 
       while (iter.hasNext) {
        val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      val getBucketingKey: InternalRow => InternalRow =
+        if (sortColumns.isEmpty && orderingExpressions.isEmpty) {
+          identity
+        } else {
+          val bucketingKeyExprs =
+            sortingExpressions.dropRight(sortColumns.length + orderingExpressions.length)
+          UnsafeProjection.create(bucketingKeyExprs.zipWithIndex.map {
+            case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+          })
+        }
 
       val sortedIterator = sorter.sortedIterator()

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala

Lines changed: 48 additions & 4 deletions

@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution
 
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -110,7 +111,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       valueSchema: StructType,
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
-      spill: Boolean): Unit = {
+      spill: Boolean,
+      sortOrdering: java.util.List[SortOrder] = null): Unit = {
     val memoryManager =
       new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"))
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
@@ -125,7 +127,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
 
     val sorter = new UnsafeKVExternalSorter(
       keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
-      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)
+      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
+      null, sortOrdering)
 
     // Insert the keys and values into the sorter
     inputData.foreach { case (k, v) =>
@@ -145,7 +148,11 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     }
     sorter.cleanupResources()
 
-    val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType))
+    val keyOrdering = if (sortOrdering == null) {
+      InterpretedOrdering.forSchema(keySchema.map(_.dataType))
+    } else {
+      new InterpretedOrdering(sortOrdering.asScala)
+    }
     val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType))
     val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
       override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
@@ -204,4 +211,41 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       spill = true
     )
   }
+
+  test("kv sorting with records that exceed page size: with specified sort order") {
+    val pageSize = 128
+
+    val keySchema = StructType(StructField("a", BinaryType) :: StructField("b", BinaryType) :: Nil)
+    val valueSchema = StructType(StructField("c", BinaryType) :: Nil)
+    val keyExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
+    val valueExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
+    val keyConverter = UnsafeProjection.create(keySchema)
+    val valueConverter = UnsafeProjection.create(valueSchema)
+
+    val rand = new Random()
+    val inputData = Seq.fill(1024) {
+      val kBytes1 = new Array[Byte](rand.nextInt(pageSize))
+      val kBytes2 = new Array[Byte](rand.nextInt(pageSize))
+      val vBytes = new Array[Byte](rand.nextInt(pageSize))
+      rand.nextBytes(kBytes1)
+      rand.nextBytes(kBytes2)
+      rand.nextBytes(vBytes)
+      val k =
+        keyConverter(keyExternalConverter.apply(Row(kBytes1, kBytes2)).asInstanceOf[InternalRow])
+      val v = valueConverter(valueExternalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
+
+    val sortOrder = SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
+      SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil
+
+    testKVSorter(
+      keySchema,
+      valueSchema,
+      inputData,
+      pageSize,
+      spill = true,
+      sortOrder.asJava
+    )
+  }
 }
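
The expected ordering in the test harness now comes from an `InterpretedOrdering` built directly from the supplied `SortOrder`s rather than from the key schema. A small sketch of what that comparator does (integer key columns are an assumption for brevity):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder}
import org.apache.spark.sql.types.IntegerType

val ordering = new InterpretedOrdering(Seq(
  SortOrder(BoundReference(0, IntegerType, nullable = true), Ascending),
  SortOrder(BoundReference(1, IntegerType, nullable = true), Descending)))

val x = InternalRow(1, 9)
val y = InternalRow(1, 5)
// Column 0 ties; column 1 is descending, so the larger value sorts first.
assert(ordering.compare(x, y) < 0)
```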

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala

Lines changed: 30 additions & 0 deletions

@@ -487,6 +487,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("SPARK-19352: Keep sort order of rows after external sorter when writing") {
+    spark.stop()
+    // Explicitly set memory configuration to force `UnsafeKVExternalSorter` to spill to files
+    // when inserting data.
+    val newSpark = SparkSession.builder()
+      .master("local")
+      .appName("test")
+      .config("spark.buffer.pageSize", "16b")
+      .config("spark.testing.memory", "1400")
+      .config("spark.memory.fraction", "0.1")
+      .config("spark.shuffle.sort.initialBufferSize", "2")
+      .config("spark.memory.offHeap.enabled", "false")
+      .getOrCreate()
+    withTempPath { path =>
+      val tempDir = path.getCanonicalPath
+      val df = newSpark.range(100)
+        .select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
+        .repartition($"id")
+        .sortWithinPartitions($"value".desc).toDF()
+
+      df.write
+        .partitionBy("id")
+        .parquet(tempDir)
+
+      val dfReadIn = newSpark.read.parquet(tempDir).select("id", "value")
+      checkAnswer(df.filter("id = 65"), dfReadIn.filter("id = 65"))
+    }
+    newSpark.stop()
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
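
The tiny memory settings (`spark.testing.memory`, `spark.memory.fraction`, a 16-byte page size) force `UnsafeKVExternalSorter` to spill, exercising the spill-and-merge path where ordering was previously lost. One way to probe the preserved order directly, as a hedged sketch that is not part of the commit: assuming each `id=...` partition maps to a single Parquet file so read-back preserves on-disk row order,

```scala
// Read a single partition directory; the partition column `id` is encoded in
// the path, so only `value` remains in the file.
val values = newSpark.read.parquet(s"$tempDir/id=65")
  .select("value").collect().map(_.getLong(0))

// sortWithinPartitions($"value".desc) should have left a non-increasing run.
assert(values.zip(values.tail).forall { case (a, b) => a >= b })
```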
