diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 95e1a159ef84..257866ace70e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -156,7 +156,12 @@ case class CreateDataSourceTableAsSelectCommand(
 
   override def requiredOrdering: Seq[SortOrder] = {
     val options = table.storage.properties
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
+    val originSortedColumns = query.outputOrdering.flatMap(_.child match {
+      case attr: Attribute => Some(attr)
+      case _ => None
+    })
+    V1WritesUtils.getSortOrder(originSortedColumns, outputColumns,
+      partitionColumns, table.bucketSpec, options)
   }
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 41b55a3b6e93..41356a6d4f74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -74,9 +74,14 @@ case class InsertIntoHadoopFsRelationCommand(
     staticPartitions.size < partitionColumns.length
   }
 
-  override def requiredOrdering: Seq[SortOrder] =
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options,
-      staticPartitions.size)
+  override def requiredOrdering: Seq[SortOrder] = {
+    val originSortedColumns = query.outputOrdering.flatMap(_.child match {
+      case attr: Attribute => Some(attr)
+      case _ => None
+    })
+    V1WritesUtils.getSortOrder(originSortedColumns, outputColumns,
+      partitionColumns, bucketSpec, options, staticPartitions.size)
+  }
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
     // Most formats don't do well with duplicate columns, so lets not allow that
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index d082b95739cb..5a84b5bac177 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -158,6 +158,7 @@
   }
 
   def getSortOrder(
+      originSortColumns: Seq[Attribute],
       outputColumns: Seq[Attribute],
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
@@ -166,6 +167,7 @@
     require(partitionColumns.size >= numStaticPartitionCols)
 
     val partitionSet = AttributeSet(partitionColumns)
+    val originSortSet = outputColumns.filter(AttributeSet(originSortColumns).contains)
     val dataColumns = outputColumns.filterNot(partitionSet.contains)
     val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options)
     val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns)
@@ -178,7 +180,15 @@
     } else {
       // We should first sort by dynamic partition columns, then bucket id, and finally sorting
       // columns.
-      (dynamicPartitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns)
+      val sortOrder = (dynamicPartitionColumns ++
+        writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns)
+      val exprIdSet = sortOrder.flatMap({
+        case a: Attribute => Some(a.exprId)
+        case _ => None
+      }).toSet
+      val residualSort = originSortSet.filterNot(s => (sortOrder.contains(s)
+        || exprIdSet.contains(s.exprId)))
+      (sortOrder ++ residualSort)
         .map(SortOrder(_, Ascending))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index 9085eff69dc1..6080255868cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -220,6 +220,23 @@ class PartitionedWriteSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-40885: V1 write uses the sort with partitionBy operator") {
+    withTempPath { f =>
+      Seq((20, 30, "partition"), (15, 20, "partition"),
+        (30, 70, "partition"), (18, 40, "partition"))
+        .toDF("id", "sort_col", "p")
+        .repartition(1)
+        .sortWithinPartitions("p", "sort_col")
+        .write
+        .partitionBy("p")
+        .parquet(f.getAbsolutePath)
+      val sortColList = spark.read.parquet(f.getAbsolutePath)
+        .map(_.getInt(1)).collect().toList
+      val expectList = List(20, 30, 40, 70)
+      assert(sortColList == expectList)
+    }
+  }
 }
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index ce3207750277..f71fbe4d70b3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -54,7 +54,12 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils
 
   override def requiredOrdering: Seq[SortOrder] = {
     val options = getOptionsWithHiveBucketWrite(tableDesc.bucketSpec)
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, tableDesc.bucketSpec, options)
+    val originSortedColumns = query.outputOrdering.flatMap(_.child match {
+      case attr: Attribute => Some(attr)
+      case _ => None
+    })
+    V1WritesUtils.getSortOrder(originSortedColumns, outputColumns,
+      partitionColumns, tableDesc.bucketSpec, options)
   }
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 8c3aa0a80c1b..605e1209d53e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -82,7 +82,12 @@ case class InsertIntoHiveTable(
 
   override def requiredOrdering: Seq[SortOrder] = {
     val options = getOptionsWithHiveBucketWrite(table.bucketSpec)
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
+    val originSortedColumns = query.outputOrdering.flatMap(_.child match {
+      case attr: Attribute => Some(attr)
+      case _ => None
+    })
+    V1WritesUtils.getSortOrder(originSortedColumns, outputColumns,
+      partitionColumns, table.bucketSpec, options)
   }
 
   /**
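
A minimal, self-contained sketch (plain Scala, no Spark dependencies) of the residual-sort rule the V1Writes.scala hunk adds: the query's originally sorted columns are appended after the required partition/bucket sort keys unless those keys already cover them. The Column case class and requiredSortColumns helper are illustrative stand-ins, not Spark APIs.

// Illustrative sketch only: Column stands in for Spark's Attribute, and
// requiredSortColumns mirrors the residual-sort handling added to
// V1WritesUtils.getSortOrder. None of these names are Spark APIs.
object ResidualSortSketch {
  case class Column(name: String, exprId: Long)

  // requiredKeys: dynamic partition columns, bucket id, bucket sort columns.
  // originSortColumns: columns the incoming query is already sorted by.
  def requiredSortColumns(
      requiredKeys: Seq[Column],
      originSortColumns: Seq[Column]): Seq[Column] = {
    val coveredIds = requiredKeys.map(_.exprId).toSet
    // Keep only the originally sorted columns not already covered by the required keys.
    val residual = originSortColumns.filterNot(c => coveredIds.contains(c.exprId))
    requiredKeys ++ residual
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the new test: data sorted within partitions by (p, sort_col) and
    // written with partitionBy("p") keeps sort_col as a trailing sort key.
    val p = Column("p", 1L)
    val sortCol = Column("sort_col", 2L)
    assert(requiredSortColumns(Seq(p), Seq(p, sortCol)) == Seq(p, sortCol))
  }
}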