@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
@@ -68,7 +69,8 @@ object FileFormatWriter extends Logging {
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val orderingInPartition: Seq[SortOrder])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
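
For context, a Catalyst `SortOrder` bundles the ordered expression with its direction and null ordering, which is exactly what the new `orderingInPartition` field has to carry through to the writer. A minimal sketch of such a value, with a made-up column:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Descending, NullsLast, SortOrder}
import org.apache.spark.sql.types.LongType

// Hypothetical ordering equivalent to ORDER BY ts DESC NULLS LAST.
val ts = AttributeReference("ts", LongType)()
val orderingInPartition: Seq[SortOrder] = Seq(SortOrder(ts, Descending, NullsLast))
```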
@@ -125,7 +127,8 @@ object FileFormatWriter extends Logging {
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      orderingInPartition = queryExecution.executedPlan.outputOrdering
     )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
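
`queryExecution.executedPlan.outputOrdering` is non-empty when the plan already sorts its output, for example after a per-partition sort. A hedged usage sketch of the case this patch targets (paths and column names are made up):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// With this change, the writer sees the plan's outputOrdering (eventTime here)
// and keeps it, instead of only sorting by partition columns and bucket id.
spark.read.parquet("/tmp/in")            // hypothetical input
  .sortWithinPartitions($"eventTime")    // gives the plan a non-empty outputOrdering
  .write
  .partitionBy("date")                   // hypothetical partition column
  .parquet("/tmp/out")
```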
@@ -368,17 +371,58 @@ object FileFormatWriter extends Logging {
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
+      // If the incoming data already has a sort ordering, preserve it.
+      val orderingExpressions: Seq[Expression] = if (description.orderingInPartition.isEmpty) {
+        Nil
+      } else {
+        description.orderingInPartition.map(_.child)
+      }
+
+      // We should first sort by partition columns, then bucket id, then the existing
+      // ordering in the data, and finally the sorting columns.
       val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
+        description.partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
       val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
+      val bucketIdExprIndex =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length - 1
+
+      val sortingKeySchema = StructType(sortingExpressions.zipWithIndex.map { case (e, index) =>
+        e match {
+          case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+          // The sorting expressions are all `Attribute` except the bucket id and the
+          // sort orders' child expressions.
+          case _ if index == bucketIdExprIndex =>
+            StructField("bucketId", IntegerType, nullable = false)
+          case _ if index > bucketIdExprIndex =>
+            StructField(s"_sortOrder_$index", e.dataType, e.nullable)
+        }
       })
 
+      val beginSortingExpr =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length
+      val recordSortingOrder =
+        if (description.orderingInPartition.isEmpty) {
+          null
+        } else {
+          sortingExpressions.zipWithIndex.map { case (field, ordinal) =>
+            if (ordinal < beginSortingExpr ||
+                ordinal >= beginSortingExpr + orderingExpressions.length) {
+              // Partition columns, the bucket id and the sort columns are sorted ascending.
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
+            } else {
+              // For the existing ordering of the data, keep its sort direction and
+              // null ordering.
+              val direction =
+                description.orderingInPartition(ordinal - beginSortingExpr).direction
+              val nullOrdering =
+                description.orderingInPartition(ordinal - beginSortingExpr).nullOrdering
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true),
+                direction, nullOrdering)
+            }
+          }.asJava
+        }
+
       // Returns the data columns to be written given an input row
       val getOutputRow = UnsafeProjection.create(
         description.dataColumns, description.allColumns)
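
The index arithmetic above is easiest to see with concrete numbers; a small self-contained sketch (the counts are invented):

```scala
// Key layout: [p0, p1, bucketId, o0, o1, s0]
//   2 partition columns, 1 bucket id expression,
//   2 ordering expressions from the data, 1 sort column.
val total = 2 + 1 + 2 + 1                  // sortingExpressions.length == 6
val bucketIdExprIndex = total - 1 - 2 - 1  // == 2, the bucketId slot
val beginSortingExpr = total - 1 - 2       // == 3, first ordering expression

// Only ordinals 3 and 4 take their direction and null ordering from
// description.orderingInPartition; 0, 1, 2 and 5 fall back to ascending.
val fromOrdering = (0 until total).filter { ordinal =>
  ordinal >= beginSortingExpr && ordinal < beginSortingExpr + 2
}
assert(fromOrdering == Seq(3, 4))
```

Note the `>=` in the upper-bound check: with a strict `>` the first sort column (ordinal 5 here) would be routed to the ordering branch and index past the end of `orderingInPartition`.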
@@ -395,20 +439,25 @@ object FileFormatWriter extends Logging {
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
         SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
+          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+        null,
+        recordSortingOrder)
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      val getBucketingKey: InternalRow => InternalRow =
+        if (sortColumns.isEmpty && orderingExpressions.isEmpty) {
+          identity
+        } else {
+          val bucketingKeyExprs =
+            sortingExpressions.dropRight(sortColumns.length + orderingExpressions.length)
+          UnsafeProjection.create(bucketingKeyExprs.zipWithIndex.map {
+            case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+          })
+        }
 
       val sortedIterator = sorter.sortedIterator()
 
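
Downstream, the write task walks `sortedIterator()` and opens a new output file whenever the bucketing key (partition values plus bucket id, with the ordering and sort columns dropped) changes, which is why `getBucketingKey` must now also drop `orderingExpressions`. A simplified sketch of that prefix comparison, using plain tuples instead of `UnsafeRow`s:

```scala
// (partitionValue, bucketId, orderingValue) triples, already sorted.
val keys = Seq(
  ("2017-01-01", 0, 5L),
  ("2017-01-01", 0, 7L),
  ("2017-01-02", 0, 1L))

// Bucketing key = everything except the trailing ordering/sort values.
def bucketingKey(k: (String, Int, Long)): (String, Int) = (k._1, k._2)

// A new file starts at every change of the bucketing key.
val newFiles = keys.sliding(2).count {
  case Seq(prev, next) => bucketingKey(prev) != bucketingKey(next)
}
assert(newFiles == 1) // only the date change opens a new file
```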